Compare commits

...

102 Commits

Author SHA1 Message Date
刀刀
9d30ceae8d CC 400 返回具体错误信息 && 非 CC 请求时增加 system prompt (#26)
* feat: http 400 返回具体错误

* 更新 workflows

* 优化打包/docker 构建流程

* 400 是返回 原始错误 - json 格式

* feat: 非 cc请求时补充 system

* go mod tidy
2025-12-25 14:47:19 +08:00
IanShaw
60f6ed6bf6 feat: CRS 同步增强 - 自动刷新 OAuth token 和修复测试配置 (#27)
* fix(service): 修复 OpenAI Responses API 测试负载配置

- 所有账号类型统一添加 instructions 字段(不再仅限 OAuth)
- Responses API 要求所有请求必须包含 instructions 参数

* feat(crs-sync): CRS 同步时自动刷新 OAuth token 并保留完整 extra 字段

**核心功能**:
- CRSSyncService 注入 OAuth 服务依赖(Anthropic + OpenAI)
- 账号创建/更新后自动刷新 OAuth token,确保可用性
- 完整保留 CRS extra 字段,避免数据丢失

**Extra 字段增强**:
- 保留 CRS 所有原始 extra 字段
- 新增同步元数据: crs_account_id, crs_kind, crs_synced_at
- Claude 账号: 从 credentials 提取 org_uuid/account_uuid 到 extra
- OpenAI 账号: 映射 crs_email -> email

**Token 刷新逻辑**:
- 新增 refreshOAuthToken() 方法处理 Anthropic/OpenAI 平台
- 保留原有 credentials 字段,仅更新 token 相关字段
- 刷新失败静默处理,不中断同步流程

**依赖注入**:
- wire_gen.go: CRSSyncService 新增 oAuthService/openaiOAuthService

* style(crs-sync): 使用 switch 替代 if-else 修复 golangci-lint 警告

- 将 refreshOAuthToken 中的 if-else 改为 switch 语句
- 符合 staticcheck 规范
- 添加 default 分支处理未知平台
2025-12-25 14:45:17 +08:00
shaw
4a2f7d4a99 chore: CRS迁移功能增加版本提示 2025-12-25 10:57:04 +08:00
shaw
c19a393be9 Merge PR #24: feat: 添加账户同步与批量编辑功能
- 添加从 CRS 同步账户功能 (Claude OAuth/API Key, OpenAI OAuth/Responses)
- 添加批量编辑账户功能,支持 JSONB 字段智能合并
- 新增 CRSSyncService、BulkUpdate 仓储方法
- 前端新增 SyncFromCrsModal 和 BulkEditAccountModal 组件
2025-12-25 10:44:40 +08:00
ianshaw
938ffb002e style(frontend): format code with prettier
格式化前端业务代码,符合代码规范
- 统一代码风格
- 修复 ESLint 警告
2025-12-24 18:07:58 -08:00
ianshaw
372a01290b fix(backend): handle defer Close() errors in crs_sync_service
修复 golangci-lint 错误检查问题
- 使用匿名函数包装 defer Close() 并忽略错误
- 符合 Go 最佳实践
2025-12-24 17:58:47 -08:00
ianshaw
8b163ca49b chore: trigger CI after enabling Actions 2025-12-24 17:56:55 -08:00
ianshaw
d23810dc53 chore: trigger CI workflow 2025-12-24 17:54:43 -08:00
ianshaw
62ed5422dd feat(account): 优化批量更新实现,使用统一 SQL 合并 JSONB 字段
- 新增 BulkUpdate 仓储方法,使用单条 SQL 更新所有账户
- credentials/extra 使用 COALESCE(...) || ? 合并,只更新传入的 key
- name/proxy_id/concurrency/priority/status 只在提供时更新
- 分组绑定仍逐账号处理(需要独立操作)
- 前端优化:Base URL 留空则不修改,按勾选字段更新
- 完善 i18n 文案:说明留空不修改、批量更新行为
2025-12-24 17:16:19 -08:00
ianshaw
2e76302af7 feat(account): 添加批量编辑账户凭据功能并优化 CRS 同步
- 新增批量更新账户凭据接口(account_uuid/org_uuid/intercept_warmup_requests)
- 新增前端批量编辑模态框组件
- 优化 CRS 同步逻辑,改进 extra 字段处理
- 优化 CRS 同步 UI,添加更详细的结果展示
- 完善国际化文案(中英文)
2025-12-24 16:56:48 -08:00
ianshaw
6553828008 feat(account): 添加从 CRS 同步账户功能
- 添加账户同步 API 接口 (account_handler.go)
- 实现 CRS 同步服务 (crs_sync_service.go)
- 添加前端同步对话框组件 (SyncFromCrsModal.vue)
- 更新账户管理界面支持同步操作
- 添加账户仓库批量创建方法
- 添加中英文国际化翻译
- 更新依赖注入配置
2025-12-24 08:48:58 -08:00
ianshaw
adcb7bf00e chore: 更新 .gitignore 忽略配置文件并还原 Makefile
- 添加 backend/config.yaml 到 .gitignore(包含敏感信息)
- 添加 deploy/config.yaml 到 .gitignore(包含敏感信息)
- 添加 backend/.installed 到 .gitignore
- 还原 Makefile 到原始版本
2025-12-24 08:48:49 -08:00
shaw
876e85e7ad Merge branch 'feat/rename-go-module' 2025-12-24 21:34:37 +08:00
shaw
2e7818d688 feat(settings): 添加文档链接配置功能
- 后台系统设置新增文档链接(doc_url)配置项
- 首页顶部导航栏显示文档链接图标(条件渲染)
- Footer区域添加文档链接和GitHub链接
- 支持中英文国际化
2025-12-24 21:30:19 +08:00
Forest
836c4dda2b refactor: 重命名 go module 2025-12-24 21:07:21 +08:00
shaw
e65e9587b4 fix(concurrency): 重构并发管理使用独立Key+原生TTL
问题:旧方案使用计数器模式,每次acquire都刷新TTL,导致僵尸数据永不过期

解决方案:
- 每个槽位使用独立Redis Key: concurrency:account:{id}:{requestID}
- 利用Redis原生TTL,每个槽位独立5分钟过期
- 服务崩溃后僵尸数据自动清理,无需手动干预
- 兼容多实例K8s部署

技术改动:
- 新增SCAN脚本统计活跃槽位数量
- 移除冗余的releaseScript,直接使用DEL命令
- Wait队列TTL只在首次创建时设置,避免刷新
2025-12-24 21:00:29 +08:00
shaw
aaadd6ed04 fix(dashboard): 修复性能指标 RPM/TPM 显示为0的问题
- 修复 Admin Dashboard Handler 遗漏返回 rpm/tpm 字段
- 将性能统计时间窗口从1分钟改为5分钟平均值,数据更稳定
2025-12-24 19:58:33 +08:00
shaw
870b21916c feat(install): 添加安装指定版本和回退功能
- 新增 rollback 命令支持回退到指定版本
- 新增 list-versions 命令列出可用版本
- 新增 -v/--version 参数指定安装版本
- upgrade 命令支持升级到指定版本
- 添加安装状态检查,未安装时给出明确提示
- 版本切换仅替换二进制文件,保留配置和数据
- 自动备份当前版本(带版本号或时间戳后缀)
- 改进网络错误处理,添加超时和友好提示
- 修复 grep -oP 兼容性问题,改用 grep -oE
2025-12-24 17:44:13 +08:00
shaw
fb119f9a67 fix(version): 优化服务重启后页面刷新时机
- 将重启后等待时间从 3 秒增加到 8 秒
- 添加倒计时显示,提升用户体验
- 倒计时结束后先检测服务健康状态再刷新页面
- 避免刷新过早导致 502 错误
2025-12-24 17:21:17 +08:00
shaw
ad54795a24 feat(gateway): 添加上游错误重试机制
- OAuth/Setup Token 账号遇到 403 错误时,等待 2 秒后重试,最多 3 次
- Console 账号遇到未配置的错误码时,同样进行重试
- 重试耗尽后:OAuth 403 标记账号异常,Console 未配置错误码不标记账号
- 移除 handleErrorResponse 中已被重试逻辑覆盖的死代码
2025-12-24 16:55:46 +08:00
shaw
0abe322cca feat(accounts): 账户列表显示实时并发数
- 在账户列表 API 返回中添加 current_concurrency 字段
- 合并平台和类型列为 PlatformTypeBadge 组件,节省表格空间
- 新增并发状态列,显示 当前/最大 并发数,支持颜色编码
2025-12-24 15:44:45 +08:00
shaw
b071511676 refactor(accounts): 优化用量窗口显示,统一 OAuth 和 Setup Token 处理
- Setup Token 账号现在也调用 API 获取 5h 窗口用量数据
- 重新设计 UsageProgressBar UI,将用量统计移到进度条上方
- 删除冗余的 SetupTokenTimeWindow 组件
- 请求数/Token数支持 K/M/B 单位显示
2025-12-24 10:57:40 +08:00
shaw
7d9a757a26 feat(dashboard): 添加 RPM/TPM 性能指标
在 Dashboard 中用 RPM/TPM 卡片替换原来的"今日缓存"卡片,
实时显示最近1分钟的请求数和 Token 吞吐量。
2025-12-24 10:24:02 +08:00
Forest
bbf4024dc7 refactor(usage): 移动 usage 查询到 services 2025-12-24 08:41:31 +08:00
shaw
5831eb8a6a fix: 修复Claude OAuth token交换时authorization code解析错误
原代码中 `parts` 变量被创建但从未使用,导致 `len(parts) == 0`
永远为 true,使得即使成功从 `code#state` 格式中分割出 authCode,
最后也会被覆盖为原始的完整字符串。

这导致传递给Claude Token端点的code包含了 `#state` 部分,
Claude返回 "Invalid 'code' in request" 错误。
2025-12-23 19:42:52 +08:00
shaw
61838cdb3d fix: 兼容GLM等API的usage数据解析
部分第三方API(如GLM)的SSE响应格式与标准Claude API不同:
- 标准Claude: input_tokens在message_start中
- GLM等API: 所有tokens都在message_delta中

现在从message_delta中也解析input_tokens和cache相关字段,
如果message_start中没有值则使用message_delta中的数据。
2025-12-23 19:42:52 +08:00
dexcoder6
50dba656fd feat: 添加用户余额充值/退款功能 (#17)
## 功能特性

### 前端
- 在用户列表操作列添加充值和退款按钮
- 实现充值/退款对话框,支持输入金额和备注
- 从编辑用户表单中移除余额字段,防止直接修改
- 添加余额不足验证,实时显示操作后余额
- 优化备注提示词,提供多种场景示例

### 后端
- 为 redeem_codes 表添加 notes 字段(迁移文件)
- 在 UpdateUserBalance 接口添加 notes 参数支持
- 添加余额验证:金额必须大于0,操作后余额不能为负
- UpdateUser 接口移除 balance 字段处理,防止误操作
- 完整的审计日志和缓存管理

## 安全保护

- 前端:余额不足时禁用提交按钮,实时提示
- 后端:双重验证(输入金额 > 0 + 结果余额 >= 0)
- 权限:仅管理员可访问(AdminAuth 中间件)
- 审计:所有操作记录到 redeem_codes 表

## 修改文件

后端:
- backend/migrations/004_add_redeem_code_notes.sql
- backend/internal/model/redeem_code.go
- backend/internal/service/admin_service.go
- backend/internal/handler/admin/user_handler.go

前端:
- frontend/src/views/admin/UsersView.vue
- frontend/src/api/admin/users.ts
- frontend/src/i18n/locales/zh.ts
- frontend/src/i18n/locales/en.ts

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-23 16:29:57 +08:00
shaw
0e2821456c chore: 忽略TypeScript增量编译缓存文件 2025-12-23 16:27:56 +08:00
shaw
f25ac3aff5 feat: OpenAI OAuth账号显示Codex使用量
从响应头提取x-codex-*使用量信息并保存到账号Extra字段,
前端账号列表展示5h/7d窗口的使用进度条。
2025-12-23 16:26:07 +08:00
shaw
f6341b7f2b chore: 将"代理管理"菜单更名为"IP管理" 2025-12-23 15:46:10 +08:00
shaw
4e257512b9 style: 统一平台和分组列的样式
- 账号页面平台列改为与分组页面一致的标签样式
- 订阅页面分组列改用 GroupBadge 组件展示
- 修正 OpenAI OAuth 类型描述文案
2025-12-23 15:40:22 +08:00
shaw
e53b34f321 Merge PR #15: feat: 增强用户管理功能,添加用户名、微信号和备注字段 2025-12-23 14:03:07 +08:00
shaw
12ddae0184 fix: 优化OpenAI模型定价查找的回退逻辑
当模型ID在model_pricing.json中找不到时,增加智能回退策略:
- gpt-5.2-codex → 回退到 gpt-5.2
- gpt-5.2-20251222 → 去掉日期后缀回退到 gpt-5.2
- 最终回退到 DefaultTestModel (gpt-5.1-codex)
2025-12-23 13:58:56 +08:00
shaw
7b9c3f165e feat: 账号管理新增使用统计功能
- 新增账号统计弹窗,展示30天使用数据
- 显示总费用、请求数、日均费用、日均请求等汇总指标
- 显示今日概览、最高费用日、最高请求日
- 包含费用与请求趋势图(双Y轴)
- 复用模型分布图组件展示模型使用分布
- 显示实际扣费和标准计费(标准计费以较淡颜色显示)
2025-12-23 13:42:33 +08:00
dexcoder6
0b8e84f942 feat: 增强用户管理功能,添加用户名、微信号和备注字段
- 新增User模型字段:username(用户名)、wechat(微信号)、notes(备注)
- 扩展用户搜索功能,支持通过用户名和微信号搜索
- 添加用户个人资料更新功能,用户可自行编辑用户名和微信号
- 管理员用户列表新增用户名、微信号、备注显示列
- 备注字段仅对管理员可见,增强数据安全性
- 完善中英文国际化翻译
- 修复国际化文件中重复属性的TypeScript错误

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-23 11:26:22 +08:00
shaw
d9e27df9af feat: 账号列表显示所属分组
- Account模型新增Groups虚拟字段
- 账号列表API预加载Group信息
- 账号管理页面新增分组列,使用GroupBadge展示
2025-12-23 11:20:02 +08:00
shaw
f0fabf89a1 feat: 用户列表显示订阅分组及剩余天数
- User模型新增Subscriptions关联
- 用户列表批量加载订阅信息避免N+1查询
- GroupBadge组件支持显示剩余天数(过期红色、<=3天红色、<=7天橙色)
- 用户管理页面新增订阅分组列
2025-12-23 11:03:10 +08:00
shaw
5bbfbcdae9 fix: 修复订阅窗口过期后进度条显示不正确的问题
问题:滑动窗口过期后(如昨天用满额度),前端仍显示历史数据(红色进度条100%、"即将重置")

解决:
- 后端返回数据前检查窗口是否过期,过期则清零展示数据
- 前端处理 window_start 为 null 的情况,显示"窗口未激活"
- 不影响实际的窗口激活逻辑,窗口仍从当天零点开始
2025-12-23 10:38:15 +08:00
shaw
eb55947ec4 fix: 修复golangci-lint检查问题
- 移除OpenAIGatewayHandler中未使用的userService字段
- 将账号类型判断的if-else链改为switch语句
2025-12-23 10:25:32 +08:00
shaw
5f7e5184eb feat: admin/subscriptions新增重置时间显示 2025-12-23 10:14:41 +08:00
shaw
008a111268 chore: 更新前端构建信息 2025-12-23 10:03:34 +08:00
shaw
fda753278c feat: 平台图标与计费修复
- fix(billing): 修复 OpenAI 兼容 API 缓存 token 重复计费问题
- fix(auth): 隐藏数据库错误详情,返回通用服务不可用错误
- feat(ui): 新增 PlatformIcon 组件,GroupBadge 支持平台颜色区分
- feat(ui): 账号管理新增重置状态按钮,重授权后自动清除错误
- feat(ui): 分组管理新增计费类型列,显示订阅限额信息
- ui: 首页 GPT 状态改为已支持
2025-12-23 10:01:58 +08:00
shaw
6c469b42ed feat: 新增支持codex转发 2025-12-22 22:58:31 +08:00
shaw
dacf3a2a6e fix: 去掉accept-encoding透传 2025-12-21 21:30:19 +08:00
shaw
e6add93ae3 fix(build): add -tags embed to ensure frontend is embedded
- Add -tags=embed flag to GoReleaser builds
- Add -tags embed flag to Dockerfile builds
- Fix Dockerfile COPY order to prevent frontend dist being overwritten
- Update README build instructions with embed tag explanation
2025-12-20 19:13:26 +08:00
NepetaLemon
b2273ec695 ci(backend): 修复 backend-ci 2025-12-20 16:52:38 +08:00
Forest
aa89777dda ci(backend): 调整 embed server 2025-12-20 16:44:25 +08:00
Forest
1e1f3c0c74 ci(backend): 添加 gofmt 配置 2025-12-20 16:19:40 +08:00
Forest
1fab9204eb ci(backend): 添加 unused 配置 2025-12-20 16:12:44 +08:00
Forest
dbd3e71637 ci(backend): 添加 staticcheck 配置 2025-12-20 16:01:24 +08:00
Forest
974f67211b ci(backend): 添加 ineffassign 配置 2025-12-20 15:58:08 +08:00
Forest
0338c83b90 ci(backend): 添加 errcheck 配置 2025-12-20 15:52:13 +08:00
NepetaLemon
c6b3de1199 ci(backend): 添加 github actions (#10)
## 变更内容

### CI/CD
- 添加 GitHub Actions 工作流(test + golangci-lint)
- 添加 golangci-lint 配置,启用 errcheck/govet/staticcheck/unused/depguard
- 通过 depguard 强制 service 层不能直接导入 repository

### 错误处理修复
- 修复 CSV 写入、SSE 流式输出、随机数生成等未处理的错误
- GenerateRedeemCode() 现在返回 error

### 资源泄露修复
- 统一使用 defer func() { _ = xxx.Close() }() 模式

### 代码清理
- 移除未使用的常量
- 简化 nil map 检查
- 统一代码格式
2025-12-20 02:29:52 -05:00
shaw
f1325e9ae6 chore: 调整Turnstile设置跳转地址 2025-12-20 15:14:36 +08:00
shaw
587012396b feat: 支持创建管理员APIKEY 2025-12-20 15:11:43 +08:00
shaw
adebd941e1 fix: 修复Oauth账号自动刷新token失败的bug 2025-12-20 13:01:58 +08:00
Wesley Liddick
bb500b7b2a Merge pull request #9 from NepetaLemon/refactor/add-http-service-ports
refactor(backend): service http ports
2025-12-19 23:35:13 -05:00
Forest
cceada7dae refactor(backend): service http ports 2025-12-20 11:57:02 +08:00
shaw
5c2e7ae265 fix: 调整订阅计费时间窗口为每日0点
- 窗口激活/重置时使用当天零点而非精确时间
- 使用服务器本地时区计算零点(支持 UTC+8 等时区)
- 窗口重置时失效 Redis 缓存,避免数据不一致
2025-12-20 11:33:06 +08:00
shaw
420bedd615 Merge PR #8: refactor(backend): 添加 service 缓存端口 2025-12-20 11:05:01 +08:00
shaw
a79f6c5e1e feat: 给所有表格页面增加刷新按钮 2025-12-20 10:48:42 +08:00
shaw
0484c59ead feat: /admin/usage页面增加模型分布情况显示 2025-12-20 10:06:55 +08:00
Forest
7bbf621490 refactor(backend): 添加 service 缓存端口 2025-12-19 23:44:18 +08:00
shaw
ef81aeb463 fix: 修复dashboard页面用户名的显示bug 2025-12-19 22:41:26 +08:00
shaw
22414326cc fix: 修复前端切换页面时logo跟标题闪烁的问题 2025-12-19 22:33:36 +08:00
Wesley Liddick
14b155c66b Merge pull request #7 from NepetaLemon/refactor/ports-pattern
refactor(backend): 引入端口接口模式
2025-12-19 08:29:04 -05:00
Forest
e99b344b2b refactor(backend): 引入端口接口模式 2025-12-19 21:26:19 +08:00
shaw
7fd94ab78b fix: 修复usage页面未显示缓存写入的问题 2025-12-19 16:57:31 +08:00
shaw
078529e51e chore: 更新docker的postgres版本为18 2025-12-19 16:42:03 +08:00
shaw
23a4cf11c8 fix: 设置默认logo作为favicon 2025-12-19 16:41:00 +08:00
shaw
d1f0902ec0 feat(account): 支持账号级别拦截预热请求
- 新增 intercept_warmup_requests 配置项,存储在 credentials 字段
- 启用后,标题生成、Warmup 等预热请求返回 mock 响应,不消耗上游 token
- 前端支持所有账号类型(OAuth、Setup Token、API Key)的开关配置
- 修复 OAuth 凭证刷新时丢失非 token 配置的问题
2025-12-19 16:39:25 +08:00
shaw
ee86dbca9d feat(account): 账号测试支持选择模型
- 新增 GET /api/v1/admin/accounts/:id/models 接口获取账号可用模型
- 账号测试弹窗新增模型选择下拉框
- 测试时支持传入 model_id 参数,不传则默认使用 Sonnet
- API Key 账号支持根据 model_mapping 映射测试模型
- 将模型常量提取到 claude 包统一管理
2025-12-19 16:00:09 +08:00
Wesley Liddick
733d4c2b85 Merge pull request #6 from dexcoder6/main
fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
2025-12-19 02:59:05 -05:00
dexcoder6
406d3f3cab fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
- 修复移动端无法打开菜单栏的问题
  - 在 app.ts 中添加 mobileOpen 状态管理
  - 修复 AppHeader.vue 中移动端菜单按钮调用错误的方法
  - 修复 AppSidebar.vue 使用本地 ref 而非全局状态的问题

- 添加移动端菜单自动关闭功能
  - 点击菜单项后自动关闭侧边栏
  - 添加 150ms 延迟以显示关闭动画

- 修复使用记录页面总消费卡片溢出问题
  - 调整总消费卡片布局,将删除线价格移至说明行
  - 添加 min-w-0 flex-1 防止内容溢出
  - 保持与其他卡片高度一致

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-19 15:55:42 +08:00
shaw
1ed93a5fd0 refactor: 提取 Claude 客户端常量到独立包
- 新增 internal/pkg/claude 包统一管理 Claude Code 相关常量
  - 统一账号测试逻辑,所有账号类型使用相同的 Claude Code 风格请求
  - 网关服务使用常量包替换硬编码的 beta header 字符串
2025-12-19 15:22:52 +08:00
shaw
463ddea36f fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
batchInputHint 中的 @ 符号需要使用 {'@'} 转义
2025-12-19 11:24:22 +08:00
shaw
e769f67699 fix(setup): 支持从配置文件读取 Setup Wizard 监听地址
Setup Wizard 之前硬编码使用 8080 端口,现在支持从 config.yaml 或
环境变量 (SERVER_HOST, SERVER_PORT) 读取监听地址,方便用户在端口
被占用时使用其他地址启动初始化向导。
2025-12-19 11:21:58 +08:00
shaw
52d2ae9708 feat(gateway): 添加 /v1/messages/count_tokens 端点
实现 Claude API 的 token 计数功能,支持 OAuth、SetupToken 和 ApiKey 三种账号类型。

特点:
- 校验订阅/余额(不扣费)
- 不计算用户和账号并发
- 不记录使用量
- 支持模型映射(ApiKey 账号)
- 支持 OAuth 账号的指纹管理和 401 重试
2025-12-19 11:12:41 +08:00
shaw
2e59998c51 fix: 代理表单字段保存时自动去除前后空格
前后端同时处理,防止因意外空格导致代理连接失败
2025-12-19 10:39:30 +08:00
shaw
32e58115cc fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
转义 batchInputPlaceholder 中的 @ 符号,防止 Vue I18n 将其误解析为链接消息语法
2025-12-19 10:32:22 +08:00
shaw
ba27026399 docs: 调整源码编译步骤的顺序 2025-12-19 09:47:17 +08:00
shaw
c15b419c4c feat(backend): 添加 event_logging 接口直接返回200
将原本在nginx处理的遥测日志请求移至后端,
忽略Claude Code客户端发送的日志数据。
2025-12-19 09:39:57 +08:00
shaw
5bd27a5d17 fix(frontend): 优化分组表单中订阅模式的字段显示逻辑
- 订阅模式下隐藏 Exclusive 字段并默认为开启状态
- 编辑分组时禁用计费类型字段,防止修改
- 移除编辑表单中无用的 subscription_type watch
2025-12-19 08:41:30 +08:00
Wesley Liddick
0e7b8aab8c Merge pull request #4 from NepetaLemon/refactor/backend-wire-provider-sets
refactor(backend): 拆分 Wire ProviderSet
2025-12-18 19:27:49 -05:00
Forest
236908c03d refactor(backend): 拆分 Wire ProviderSet 2025-12-19 00:03:29 +08:00
shaw
67d028cf50 fix: 修复用户修改密码接口404问题
将后端路由与前端API调用对齐:
- /user/profile -> /users/me
- PUT /user/password -> POST /users/me/password
2025-12-18 22:59:49 +08:00
shaw
66ba487697 fix: 修复前端github项目地址 2025-12-18 22:47:42 +08:00
Wesley Liddick
8c7875aa4d Merge pull request #3 from NepetaLemon/refactor/backend-wire-bootstrap
refactor(backend): 引入 Wire 重构服务启动与依赖组装
2025-12-18 09:12:15 -05:00
shaw
145171464f fix: 修复前端多个 bug
1. 版本号闪烁问题
   - 将版本信息缓存到 Pinia store,避免每次路由切换都重新请求
   - 添加加载占位符,版本为空时显示骨架屏

2. 管理员登录跳转问题
   - 管理员登录后现在正确跳转到 /admin/dashboard
   - 普通用户仍跳转到 /dashboard

3. Dashboard 页面空白报错
   - 修复 API 返回 null 时访问 .length 导致的 TypeError
   - 为 computed 属性添加可选链操作符保护
   - 为数据赋值添加空数组默认值
2025-12-18 22:11:29 +08:00
Forest
e5aa676853 refactor(backend): 引入 Wire 重构服务启动与依赖组装 2025-12-18 22:07:17 +08:00
shaw
9b4fc42457 feat: 实现后台在线更新功能
- 前端添加更新和重启按钮,支持一键更新 Release 构建
- 修复条件判断优先级问题,确保错误/成功状态正确显示
- 后端使用原子文件替换模式,确保更新过程安全可靠
- 在可执行文件同目录创建临时文件,保证 rename 原子性
- 删除未使用的 copyFile 函数,保持代码整洁
2025-12-18 21:15:10 +08:00
shaw
caae7e4603 feat: 改进安装脚本的交互体验和自动化流程
- 修复 curl | bash 管道模式下无法交互式输入的问题
  - 使用 /dev/tty 检测终端可用性替代 stdin 检测
  - 所有 read 命令从 /dev/tty 读取用户输入
- 安装完成后自动启动服务和启用开机自启
- 使用 ipinfo.io API 获取公网 IP 用于显示访问地址
- 简化安装完成后的输出信息
2025-12-18 20:53:29 +08:00
shaw
a26db8b3e2 fix: 修复前端页面刷新时偶发空白渲染的竞态条件问题
使用 router.isReady() 等待路由器完成初始导航后再挂载应用,
避免 RouterView 在路由未就绪时渲染空的 Comment 节点。
2025-12-18 20:45:56 +08:00
shaw
8e81e395b3 refactor: 使用行业标准方案重构服务重启逻辑
重构内容:
- 移除复杂的 sudo systemctl restart 方案
- 改用 os.Exit(0) + systemd Restart=always 的标准做法
- 删除 sudoers 配置及相关代码
- 删除 sub2api-sudoers 文件

优势:
- 代码从 85+ 行简化到 47 行
- 无需 sudo 权限配置
- 无需特殊用户 shell 配置
- 更简单、更可靠
- 符合行业最佳实践(Docker/K8s 等均采用此方案)

工作原理:
- 服务调用 os.Exit(0) 优雅退出
- systemd 检测到退出后自动重启(Restart=always)
2025-12-18 20:32:24 +08:00
shaw
f0e89992f7 fix: 使用 setsid 确保重启命令独立于父进程执行
问题原因:
- cmd.Start() 启动的子进程与父进程在同一会话中
- 当 systemctl restart 发送 SIGTERM 给父进程时
- 子进程可能也会被终止,导致重启命令无法完成

修复内容:
- 使用 setsid 创建新会话,子进程完全独立于父进程
- 分离标准输入/输出/错误流
- 确保即使父进程被 kill,重启命令仍能执行完成
2025-12-18 20:00:53 +08:00
shaw
4eaa0cf14a fix: 使用完整路径执行 sudo 和 systemctl 命令
问题原因:
- systemd 服务的 PATH 环境变量可能受限
- 直接使用 "sudo" 可能找不到可执行文件

修复内容:
- 添加 findExecutable 函数动态查找可执行文件路径
- 先尝试 exec.LookPath,再检查常见系统路径
- 添加日志显示实际使用的路径,方便调试
- 兼容不同 Linux 发行版的路径差异
2025-12-18 19:58:25 +08:00
shaw
e9ec2280ec fix: 修复 sudo 在非交互模式下无法执行的问题
问题原因:
- sudo 命令没有 -n 选项
- 在后台服务中,sudo 会尝试从终端读取密码
- 由于没有终端,命令静默失败

修复内容:
- 添加 sudo -n 选项强制非交互模式
- 如果需要密码会立即失败并返回错误,而不是挂起
2025-12-18 19:37:41 +08:00
Wesley Liddick
bb7bfb6980 Merge pull request #1 from 7836246/fix/concurrent-proxy-race-condition
fix: 修复并发请求时共享httpClient.Transport导致的竞态条件
2025-12-18 06:37:22 -05:00
shaw
b66f97c100 fix: 修复 install.sh 优先使用旧 sudoers 文件的问题
问题原因:
- install.sh 优先从 tar.gz 复制 sudoers 文件
- 旧版 Release 中的 sudoers 文件没有 /usr/bin/systemctl 路径
- 即使脚本更新了,仍然会使用旧的配置

修复内容:
- 移除对 tar.gz 中 sudoers 文件的依赖
- 总是使用脚本中内嵌的最新配置
- 确保新版脚本立即生效,无需等待新 Release
2025-12-18 19:27:47 +08:00
shaw
b51ad0d893 fix: 修复 sudoers 中 systemctl 路径不兼容的问题
问题原因:
- sudoers 只配置了 /bin/systemctl 路径
- 部分系统(如 Ubuntu 22.04+)的 systemctl 位于 /usr/bin/systemctl
- 路径不匹配导致 sudo 仍然需要密码

修复内容:
- 同时支持 /bin/systemctl 和 /usr/bin/systemctl 两个路径
- 兼容 Debian/Ubuntu 和 RHEL/CentOS 等不同发行版
2025-12-18 19:17:05 +08:00
shaw
4eb22d8ee9 fix: 修复服务用户 shell 导致无法执行 sudo 重启的问题
问题原因:
- 服务用户 sub2api 的 shell 被设置为 /bin/false
- 导致无法执行 sudo systemctl restart 命令
- 安装/升级后服务无法自动重启

修复内容:
- 新安装时使用 /bin/sh 替代 /bin/false
- 升级时自动检测并修复旧版本用户的 shell 配置
- 修复失败时给出警告和手动修复命令,不中断安装流程
2025-12-18 19:07:33 +08:00
江西小徐
2392e7cf99 fix: 修复并发请求时共享httpClient.Transport导致的竞态条件
问题描述:
当多个请求并发执行且使用不同代理配置时,它们会同时修改共享的
s.httpClient.Transport,导致请求可能使用错误的代理(数据泄露风险)
或意外失败。

修复方案:
为需要代理的请求创建独立的http.Client,而不是修改共享的httpClient.Transport。

改动内容:
- 新增 buildUpstreamRequestResult 结构体,返回请求和可选的独立client
- 修改 buildUpstreamRequest 方法,配置代理时创建独立client
- 更新 Forward 方法,根据是否有代理选择合适的client
2025-12-18 18:14:48 +08:00
210 changed files with 17995 additions and 4376 deletions

38
.github/workflows/backend-ci.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
name: CI
on:
push:
pull_request:
permissions:
contents: read
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
check-latest: true
cache: true
- name: Run tests
working-directory: backend
run: go test ./...
golangci-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
check-latest: true
cache: true
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.7
args: --timeout=5m
working-directory: backend

View File

@@ -85,6 +85,19 @@ jobs:
go-version: '1.24' go-version: '1.24'
cache-dependency-path: backend/go.sum cache-dependency-path: backend/go.sum
# Docker setup for GoReleaser
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Fetch tags with annotations - name: Fetch tags with annotations
run: | run: |
# 确保获取完整的 annotated tag 信息 # 确保获取完整的 annotated tag 信息
@@ -117,87 +130,16 @@ jobs:
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }} TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
# =========================================================================== # Update DockerHub description
# Docker Build and Push
# ===========================================================================
docker:
needs: [update-version, build-frontend]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Download VERSION artifact
uses: actions/download-artifact@v4
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
uses: actions/download-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
# Extract version from tag
- name: Extract version
id: version
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Version: $VERSION"
# Set up Docker Buildx for multi-platform builds
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Login to DockerHub
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Extract metadata for Docker
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: |
weishaw/sub2api
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=latest,enable={{is_default_branch}}
# Build and push Docker image
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
COMMIT=${{ github.sha }}
DATE=${{ github.event.head_commit.timestamp }}
cache-from: type=gha
cache-to: type=gha,mode=max
# Update DockerHub description (optional)
- name: Update DockerHub description - name: Update DockerHub description
uses: peter-evans/dockerhub-description@v4 uses: peter-evans/dockerhub-description@v4
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: weishaw/sub2api repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
short-description: "Sub2API - AI API Gateway Platform" short-description: "Sub2API - AI API Gateway Platform"
readme-filepath: ./deploy/DOCKER.md readme-filepath: ./deploy/DOCKER.md

16
.gitignore vendored
View File

@@ -28,6 +28,7 @@ node_modules/
frontend/node_modules/ frontend/node_modules/
frontend/dist/ frontend/dist/
*.local *.local
*.tsbuildinfo
# 日志 # 日志
npm-debug.log* npm-debug.log*
@@ -81,14 +82,27 @@ build/
release/ release/
# 后端嵌入的前端构建产物 # 后端嵌入的前端构建产物
# Keep a placeholder file so `//go:embed all:dist` always has a match in CI/lint,
# while still ignoring generated frontend build outputs.
backend/internal/web/dist/ backend/internal/web/dist/
!backend/internal/web/dist/
backend/internal/web/dist/*
!backend/internal/web/dist/.keep
# 后端运行时缓存数据 # 后端运行时缓存数据
backend/data/ backend/data/
# ===================
# 本地配置文件(包含敏感信息)
# ===================
backend/config.yaml
deploy/config.yaml
backend/.installed
# =================== # ===================
# 其他 # 其他
# =================== # ===================
tests tests
CLAUDE.md CLAUDE.md
.claude .claude
scripts

View File

@@ -11,6 +11,8 @@ builds:
dir: backend dir: backend
main: ./cmd/server main: ./cmd/server
binary: sub2api binary: sub2api
flags:
- -tags=embed
env: env:
- CGO_ENABLED=0 - CGO_ENABLED=0
goos: goos:
@@ -50,10 +52,58 @@ changelog:
# 禁用自动 changelog完全使用 tag 消息 # 禁用自动 changelog完全使用 tag 消息
disable: true disable: true
# Docker images
dockers:
- id: amd64
goos: linux
goarch: amd64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
- id: arm64
goos: linux
goarch: arm64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
# Docker manifests for multi-arch support
docker_manifests:
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
release: release:
github: github:
owner: Wei-Shaw owner: "{{ .Env.GITHUB_REPO_OWNER }}"
name: sub2api name: "{{ .Env.GITHUB_REPO_NAME }}"
draft: false draft: false
prerelease: auto prerelease: auto
name_template: "Sub2API {{.Version}}" name_template: "Sub2API {{.Version}}"
@@ -71,7 +121,7 @@ release:
**One-line install (Linux):** **One-line install (Linux):**
```bash ```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
``` ```
**Manual download:** **Manual download:**
@@ -79,5 +129,5 @@ release:
## 📚 Documentation ## 📚 Documentation
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api) - [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md) - [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)

View File

@@ -40,14 +40,15 @@ WORKDIR /app/backend
COPY backend/go.mod backend/go.sum ./ COPY backend/go.mod backend/go.sum ./
RUN go mod download RUN go mod download
# Copy frontend dist from previous stage # Copy backend source first
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
# Copy backend source
COPY backend/ ./ COPY backend/ ./
# Build the binary (BuildType=release for CI builds) # Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
# Build the binary (BuildType=release for CI builds, embed frontend)
RUN CGO_ENABLED=0 GOOS=linux go build \ RUN CGO_ENABLED=0 GOOS=linux go build \
-tags embed \
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \ -ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
-o /app/sub2api \ -o /app/sub2api \
./cmd/server ./cmd/server

40
Dockerfile.goreleaser Normal file
View File

@@ -0,0 +1,40 @@
# =============================================================================
# Sub2API Dockerfile for GoReleaser
# =============================================================================
# This Dockerfile is used by GoReleaser to build Docker images.
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
# Install runtime dependencies
RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
&& rm -rf /var/cache/apk/*
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
WORKDIR /app
# Copy pre-built binary from GoReleaser
COPY sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
USER sub2api
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
ENTRYPOINT ["/app/sub2api"]

View File

@@ -16,6 +16,14 @@ English | [中文](README_CN.md)
--- ---
## Demo
Try Sub2API online: **https://v2.pincc.ai/**
| Email | Password |
|-------|----------|
| admin@sub2api.com | admin123 |
## Overview ## Overview
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding. Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
@@ -208,26 +216,25 @@ Build and run from source code for development or customization.
git clone https://github.com/Wei-Shaw/sub2api.git git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api cd sub2api
# 2. Build backend # 2. Build frontend
cd backend cd frontend
go build -o sub2api ./cmd/server
# 3. Build frontend
cd ../frontend
npm install npm install
npm run build npm run build
# Output will be in ../backend/internal/web/dist/
# 4. Copy frontend build to backend (for embedding) # 3. Build backend with embedded frontend
cp -r dist ../backend/internal/web/
# 5. Create configuration file
cd ../backend cd ../backend
go build -tags embed -o sub2api ./cmd/server
# 4. Create configuration file
cp ../deploy/config.example.yaml ./config.yaml cp ../deploy/config.example.yaml ./config.yaml
# 6. Edit configuration # 5. Edit configuration
nano config.yaml nano config.yaml
``` ```
> **Note:** The `-tags embed` flag embeds the frontend into the binary. Without this flag, the binary will not serve the frontend UI.
**Key configuration in `config.yaml`:** **Key configuration in `config.yaml`:**
```yaml ```yaml
@@ -258,7 +265,7 @@ default:
``` ```
```bash ```bash
# 7. Run the application # 6. Run the application
./sub2api ./sub2api
``` ```

View File

@@ -16,6 +16,14 @@
--- ---
## 在线体验
体验地址:**https://v2.pincc.ai/**
| 邮箱 | 密码 |
|------|------|
| admin@sub2api.com | admin123 |
## 项目概述 ## 项目概述
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。 Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
@@ -208,26 +216,25 @@ docker-compose logs -f
git clone https://github.com/Wei-Shaw/sub2api.git git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api cd sub2api
# 2. 编译 # 2. 编译
cd backend cd frontend
go build -o sub2api ./cmd/server
# 3. 编译前端
cd ../frontend
npm install npm install
npm run build npm run build
# 构建产物输出到 ../backend/internal/web/dist/
# 4. 复制前端构建产物到后端(用于嵌入 # 3. 编译后端(嵌入前端
cp -r dist ../backend/internal/web/
# 5. 创建配置文件
cd ../backend cd ../backend
go build -tags embed -o sub2api ./cmd/server
# 4. 创建配置文件
cp ../deploy/config.example.yaml ./config.yaml cp ../deploy/config.example.yaml ./config.yaml
# 6. 编辑配置 # 5. 编辑配置
nano config.yaml nano config.yaml
``` ```
> **注意:** `-tags embed` 参数会将前端嵌入到二进制文件中。不使用此参数编译的程序将不包含前端界面。
**`config.yaml` 关键配置:** **`config.yaml` 关键配置:**
```yaml ```yaml
@@ -258,7 +265,7 @@ default:
``` ```
```bash ```bash
# 7. 运行应用 # 6. 运行应用
./sub2api ./sub2api
``` ```

594
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,594 @@
version: "2"
linters:
default: none
enable:
- depguard
- errcheck
- govet
- ineffassign
- staticcheck
- unused
settings:
depguard:
rules:
# Enforce: service must not depend on repository.
service-no-repository:
list-mode: original
files:
- "**/internal/service/**"
deny:
- pkg: sub2api/internal/repository
desc: "service must not import repository"
handler-no-repository:
list-mode: original
files:
- "**/internal/handler/**"
deny:
- pkg: sub2api/internal/repository
desc: "handler must not import repository"
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`.
# Such cases aren't reported by default.
# Default: false
check-blank: false
# To disable the errcheck built-in exclude list.
# See `-excludeonly` option in https://github.com/kisielk/errcheck#excluding-functions for details.
# Default: false
disable-default-exclusions: true
# List of functions to exclude from checking, where each entry is a single function to exclude.
# See https://github.com/kisielk/errcheck#excluding-functions for details.
exclude-functions:
- io/ioutil.ReadFile
- io.Copy(*bytes.Buffer)
- io.Copy(os.Stdout)
- fmt.Println
- fmt.Print
- fmt.Printf
- fmt.Fprint
- fmt.Fprintf
- fmt.Fprintln
# Display function signature instead of selector.
# Default: false
verbose: true
ineffassign:
# Check escaping variables of type error, may cause false positives.
# Default: false
check-escaping-errors: true
staticcheck:
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
dot-import-whitelist:
- fmt
# https://staticcheck.dev/docs/configuration/options/#initialisms
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
# Default: ["200", "400", "404", "500"]
http-status-code-whitelist: [ "200", "400", "404", "500" ]
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
checks:
# Invalid regular expression.
# https://staticcheck.dev/docs/checks/#SA1000
- SA1000
# Invalid template.
# https://staticcheck.dev/docs/checks/#SA1001
- SA1001
# Invalid format in 'time.Parse'.
# https://staticcheck.dev/docs/checks/#SA1002
- SA1002
# Unsupported argument to functions in 'encoding/binary'.
# https://staticcheck.dev/docs/checks/#SA1003
- SA1003
# Suspiciously small untyped constant in 'time.Sleep'.
# https://staticcheck.dev/docs/checks/#SA1004
- SA1004
# Invalid first argument to 'exec.Command'.
# https://staticcheck.dev/docs/checks/#SA1005
- SA1005
# 'Printf' with dynamic first argument and no further arguments.
# https://staticcheck.dev/docs/checks/#SA1006
- SA1006
# Invalid URL in 'net/url.Parse'.
# https://staticcheck.dev/docs/checks/#SA1007
- SA1007
# Non-canonical key in 'http.Header' map.
# https://staticcheck.dev/docs/checks/#SA1008
- SA1008
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
# https://staticcheck.dev/docs/checks/#SA1010
- SA1010
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
# https://staticcheck.dev/docs/checks/#SA1011
- SA1011
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
# https://staticcheck.dev/docs/checks/#SA1012
- SA1012
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
# https://staticcheck.dev/docs/checks/#SA1013
- SA1013
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
# https://staticcheck.dev/docs/checks/#SA1014
- SA1014
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
# https://staticcheck.dev/docs/checks/#SA1015
- SA1015
# Trapping a signal that cannot be trapped.
# https://staticcheck.dev/docs/checks/#SA1016
- SA1016
# Channels used with 'os/signal.Notify' should be buffered.
# https://staticcheck.dev/docs/checks/#SA1017
- SA1017
# 'strings.Replace' called with 'n == 0', which does nothing.
# https://staticcheck.dev/docs/checks/#SA1018
- SA1018
# Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019
- SA1019
# Using an invalid host:port pair with a 'net.Listen'-related function.
# https://staticcheck.dev/docs/checks/#SA1020
- SA1020
# Using 'bytes.Equal' to compare two 'net.IP'.
# https://staticcheck.dev/docs/checks/#SA1021
- SA1021
# Modifying the buffer in an 'io.Writer' implementation.
# https://staticcheck.dev/docs/checks/#SA1023
- SA1023
# A string cutset contains duplicate characters.
# https://staticcheck.dev/docs/checks/#SA1024
- SA1024
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
# https://staticcheck.dev/docs/checks/#SA1025
- SA1025
# Cannot marshal channels or functions.
# https://staticcheck.dev/docs/checks/#SA1026
- SA1026
# Atomic access to 64-bit variable must be 64-bit aligned.
# https://staticcheck.dev/docs/checks/#SA1027
- SA1027
# 'sort.Slice' can only be used on slices.
# https://staticcheck.dev/docs/checks/#SA1028
- SA1028
# Inappropriate key in call to 'context.WithValue'.
# https://staticcheck.dev/docs/checks/#SA1029
- SA1029
# Invalid argument in call to a 'strconv' function.
# https://staticcheck.dev/docs/checks/#SA1030
- SA1030
# Overlapping byte slices passed to an encoder.
# https://staticcheck.dev/docs/checks/#SA1031
- SA1031
# Wrong order of arguments to 'errors.Is'.
# https://staticcheck.dev/docs/checks/#SA1032
- SA1032
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
# https://staticcheck.dev/docs/checks/#SA2000
- SA2000
# Empty critical section, did you mean to defer the unlock?.
# https://staticcheck.dev/docs/checks/#SA2001
- SA2001
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
# https://staticcheck.dev/docs/checks/#SA2002
- SA2002
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
# https://staticcheck.dev/docs/checks/#SA2003
- SA2003
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
# https://staticcheck.dev/docs/checks/#SA3000
- SA3000
# Assigning to 'b.N' in benchmarks distorts the results.
# https://staticcheck.dev/docs/checks/#SA3001
- SA3001
# Binary operator has identical expressions on both sides.
# https://staticcheck.dev/docs/checks/#SA4000
- SA4000
# '&*x' gets simplified to 'x', it does not copy 'x'.
# https://staticcheck.dev/docs/checks/#SA4001
- SA4001
# Comparing unsigned values against negative values is pointless.
# https://staticcheck.dev/docs/checks/#SA4003
- SA4003
# The loop exits unconditionally after one iteration.
# https://staticcheck.dev/docs/checks/#SA4004
- SA4004
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
# https://staticcheck.dev/docs/checks/#SA4005
- SA4005
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
# https://staticcheck.dev/docs/checks/#SA4006
- SA4006
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
# https://staticcheck.dev/docs/checks/#SA4008
- SA4008
# A function argument is overwritten before its first use.
# https://staticcheck.dev/docs/checks/#SA4009
- SA4009
# The result of 'append' will never be observed anywhere.
# https://staticcheck.dev/docs/checks/#SA4010
- SA4010
# Break statement with no effect. Did you mean to break out of an outer loop?.
# https://staticcheck.dev/docs/checks/#SA4011
- SA4011
# Comparing a value against NaN even though no value is equal to NaN.
# https://staticcheck.dev/docs/checks/#SA4012
- SA4012
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
# https://staticcheck.dev/docs/checks/#SA4013
- SA4013
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
# https://staticcheck.dev/docs/checks/#SA4014
- SA4014
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
# https://staticcheck.dev/docs/checks/#SA4015
- SA4015
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
# https://staticcheck.dev/docs/checks/#SA4016
- SA4016
# Discarding the return values of a function without side effects, making the call pointless.
# https://staticcheck.dev/docs/checks/#SA4017
- SA4017
# Self-assignment of variables.
# https://staticcheck.dev/docs/checks/#SA4018
- SA4018
# Multiple, identical build constraints in the same file.
# https://staticcheck.dev/docs/checks/#SA4019
- SA4019
# Unreachable case clause in a type switch.
# https://staticcheck.dev/docs/checks/#SA4020
- SA4020
# "x = append(y)" is equivalent to "x = y".
# https://staticcheck.dev/docs/checks/#SA4021
- SA4021
# Comparing the address of a variable against nil.
# https://staticcheck.dev/docs/checks/#SA4022
- SA4022
# Impossible comparison of interface value with untyped nil.
# https://staticcheck.dev/docs/checks/#SA4023
- SA4023
# Checking for impossible return value from a builtin function.
# https://staticcheck.dev/docs/checks/#SA4024
- SA4024
# Integer division of literals that results in zero.
# https://staticcheck.dev/docs/checks/#SA4025
- SA4025
# Go constants cannot express negative zero.
# https://staticcheck.dev/docs/checks/#SA4026
- SA4026
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
# https://staticcheck.dev/docs/checks/#SA4027
- SA4027
# 'x % 1' is always zero.
# https://staticcheck.dev/docs/checks/#SA4028
- SA4028
# Ineffective attempt at sorting slice.
# https://staticcheck.dev/docs/checks/#SA4029
- SA4029
# Ineffective attempt at generating random number.
# https://staticcheck.dev/docs/checks/#SA4030
- SA4030
# Checking never-nil value against nil.
# https://staticcheck.dev/docs/checks/#SA4031
- SA4031
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
# https://staticcheck.dev/docs/checks/#SA4032
- SA4032
# Assignment to nil map.
# https://staticcheck.dev/docs/checks/#SA5000
- SA5000
# Deferring 'Close' before checking for a possible error.
# https://staticcheck.dev/docs/checks/#SA5001
- SA5001
# The empty for loop ("for {}") spins and can block the scheduler.
# https://staticcheck.dev/docs/checks/#SA5002
- SA5002
# Defers in infinite loops will never execute.
# https://staticcheck.dev/docs/checks/#SA5003
- SA5003
# "for { select { ..." with an empty default branch spins.
# https://staticcheck.dev/docs/checks/#SA5004
- SA5004
# The finalizer references the finalized object, preventing garbage collection.
# https://staticcheck.dev/docs/checks/#SA5005
- SA5005
# Infinite recursive call.
# https://staticcheck.dev/docs/checks/#SA5007
- SA5007
# Invalid struct tag.
# https://staticcheck.dev/docs/checks/#SA5008
- SA5008
# Invalid Printf call.
# https://staticcheck.dev/docs/checks/#SA5009
- SA5009
# Impossible type assertion.
# https://staticcheck.dev/docs/checks/#SA5010
- SA5010
# Possible nil pointer dereference.
# https://staticcheck.dev/docs/checks/#SA5011
- SA5011
# Passing odd-sized slice to function expecting even size.
# https://staticcheck.dev/docs/checks/#SA5012
- SA5012
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
# https://staticcheck.dev/docs/checks/#SA6000
- SA6000
# Missing an optimization opportunity when indexing maps by byte slices.
# https://staticcheck.dev/docs/checks/#SA6001
- SA6001
# Storing non-pointer values in 'sync.Pool' allocates memory.
# https://staticcheck.dev/docs/checks/#SA6002
- SA6002
# Converting a string to a slice of runes before ranging over it.
# https://staticcheck.dev/docs/checks/#SA6003
- SA6003
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
# https://staticcheck.dev/docs/checks/#SA6005
- SA6005
# Using io.WriteString to write '[]byte'.
# https://staticcheck.dev/docs/checks/#SA6006
- SA6006
# Defers in range loops may not run when you expect them to.
# https://staticcheck.dev/docs/checks/#SA9001
- SA9001
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
# https://staticcheck.dev/docs/checks/#SA9002
- SA9002
# Empty body in an if or else branch.
# https://staticcheck.dev/docs/checks/#SA9003
- SA9003
# Only the first constant has an explicit type.
# https://staticcheck.dev/docs/checks/#SA9004
- SA9004
# Trying to marshal a struct with no public fields nor custom marshaling.
# https://staticcheck.dev/docs/checks/#SA9005
- SA9005
# Dubious bit shifting of a fixed size integer value.
# https://staticcheck.dev/docs/checks/#SA9006
- SA9006
# Deleting a directory that shouldn't be deleted.
# https://staticcheck.dev/docs/checks/#SA9007
- SA9007
# 'else' branch of a type assertion is probably not reading the right value.
# https://staticcheck.dev/docs/checks/#SA9008
- SA9008
# Ineffectual Go compiler directive.
# https://staticcheck.dev/docs/checks/#SA9009
- SA9009
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- ST1000
# Dot imports are discouraged.
# https://staticcheck.dev/docs/checks/#ST1001
- ST1001
# Poorly chosen identifier.
# https://staticcheck.dev/docs/checks/#ST1003
- ST1003
# Incorrectly formatted error string.
# https://staticcheck.dev/docs/checks/#ST1005
- ST1005
# Poorly chosen receiver name.
# https://staticcheck.dev/docs/checks/#ST1006
- ST1006
# A function's error value should be its last return value.
# https://staticcheck.dev/docs/checks/#ST1008
- ST1008
# Poorly chosen name for variable of type 'time.Duration'.
# https://staticcheck.dev/docs/checks/#ST1011
- ST1011
# Poorly chosen name for error variable.
# https://staticcheck.dev/docs/checks/#ST1012
- ST1012
# Should use constants for HTTP error codes, not magic numbers.
# https://staticcheck.dev/docs/checks/#ST1013
- ST1013
# A switch's default case should be the first or last case.
# https://staticcheck.dev/docs/checks/#ST1015
- ST1015
# Use consistent method receiver names.
# https://staticcheck.dev/docs/checks/#ST1016
- ST1016
# Don't use Yoda conditions.
# https://staticcheck.dev/docs/checks/#ST1017
- ST1017
# Avoid zero-width and control characters in string literals.
# https://staticcheck.dev/docs/checks/#ST1018
- ST1018
# Importing the same package multiple times.
# https://staticcheck.dev/docs/checks/#ST1019
- ST1019
# The documentation of an exported function should start with the function's name.
# https://staticcheck.dev/docs/checks/#ST1020
- ST1020
# The documentation of an exported type should start with type's name.
# https://staticcheck.dev/docs/checks/#ST1021
- ST1021
# The documentation of an exported variable or constant should start with variable's name.
# https://staticcheck.dev/docs/checks/#ST1022
- ST1022
# Redundant type in variable declaration.
# https://staticcheck.dev/docs/checks/#ST1023
- ST1023
# Use plain channel send or receive instead of single-case select.
# https://staticcheck.dev/docs/checks/#S1000
- S1000
# Replace for loop with call to copy.
# https://staticcheck.dev/docs/checks/#S1001
- S1001
# Omit comparison with boolean constant.
# https://staticcheck.dev/docs/checks/#S1002
- S1002
# Replace call to 'strings.Index' with 'strings.Contains'.
# https://staticcheck.dev/docs/checks/#S1003
- S1003
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
# https://staticcheck.dev/docs/checks/#S1004
- S1004
# Drop unnecessary use of the blank identifier.
# https://staticcheck.dev/docs/checks/#S1005
- S1005
# Use "for { ... }" for infinite loops.
# https://staticcheck.dev/docs/checks/#S1006
- S1006
# Simplify regular expression by using raw string literal.
# https://staticcheck.dev/docs/checks/#S1007
- S1007
# Simplify returning boolean expression.
# https://staticcheck.dev/docs/checks/#S1008
- S1008
# Omit redundant nil check on slices, maps, and channels.
# https://staticcheck.dev/docs/checks/#S1009
- S1009
# Omit default slice index.
# https://staticcheck.dev/docs/checks/#S1010
- S1010
# Use a single 'append' to concatenate two slices.
# https://staticcheck.dev/docs/checks/#S1011
- S1011
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
# https://staticcheck.dev/docs/checks/#S1012
- S1012
# Use a type conversion instead of manually copying struct fields.
# https://staticcheck.dev/docs/checks/#S1016
- S1016
# Replace manual trimming with 'strings.TrimPrefix'.
# https://staticcheck.dev/docs/checks/#S1017
- S1017
# Use "copy" for sliding elements.
# https://staticcheck.dev/docs/checks/#S1018
- S1018
# Simplify "make" call by omitting redundant arguments.
# https://staticcheck.dev/docs/checks/#S1019
- S1019
# Omit redundant nil check in type assertion.
# https://staticcheck.dev/docs/checks/#S1020
- S1020
# Merge variable declaration and assignment.
# https://staticcheck.dev/docs/checks/#S1021
- S1021
# Omit redundant control flow.
# https://staticcheck.dev/docs/checks/#S1023
- S1023
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
# https://staticcheck.dev/docs/checks/#S1024
- S1024
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
# https://staticcheck.dev/docs/checks/#S1025
- S1025
# Simplify error construction with 'fmt.Errorf'.
# https://staticcheck.dev/docs/checks/#S1028
- S1028
# Range over the string directly.
# https://staticcheck.dev/docs/checks/#S1029
- S1029
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
# https://staticcheck.dev/docs/checks/#S1030
- S1030
# Omit redundant nil check around loop.
# https://staticcheck.dev/docs/checks/#S1031
- S1031
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
# https://staticcheck.dev/docs/checks/#S1032
- S1032
# Unnecessary guard around call to "delete".
# https://staticcheck.dev/docs/checks/#S1033
- S1033
# Use result of type assertion to simplify cases.
# https://staticcheck.dev/docs/checks/#S1034
- S1034
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
# https://staticcheck.dev/docs/checks/#S1035
- S1035
# Unnecessary guard around map access.
# https://staticcheck.dev/docs/checks/#S1036
- S1036
# Elaborate way of sleeping.
# https://staticcheck.dev/docs/checks/#S1037
- S1037
# Unnecessarily complex way of printing formatted string.
# https://staticcheck.dev/docs/checks/#S1038
- S1038
# Unnecessary use of 'fmt.Sprint'.
# https://staticcheck.dev/docs/checks/#S1039
- S1039
# Type assertion to current type.
# https://staticcheck.dev/docs/checks/#S1040
- S1040
# Apply De Morgan's law.
# https://staticcheck.dev/docs/checks/#QF1001
- QF1001
# Convert untagged switch to tagged switch.
# https://staticcheck.dev/docs/checks/#QF1002
- QF1002
# Convert if/else-if chain to tagged switch.
# https://staticcheck.dev/docs/checks/#QF1003
- QF1003
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
# https://staticcheck.dev/docs/checks/#QF1004
- QF1004
# Expand call to 'math.Pow'.
# https://staticcheck.dev/docs/checks/#QF1005
- QF1005
# Lift 'if'+'break' into loop condition.
# https://staticcheck.dev/docs/checks/#QF1006
- QF1006
# Merge conditional assignment into variable declaration.
# https://staticcheck.dev/docs/checks/#QF1007
- QF1007
# Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008
- QF1008
# Use 'time.Time.Equal' instead of '==' operator.
# https://staticcheck.dev/docs/checks/#QF1009
- QF1009
# Convert slice of bytes to string when printing it.
# https://staticcheck.dev/docs/checks/#QF1010
- QF1010
# Omit redundant type from variable declaration.
# https://staticcheck.dev/docs/checks/#QF1011
- QF1011
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
# https://staticcheck.dev/docs/checks/#QF1012
- QF1012
unused:
# Mark all struct fields that have been written to as used.
# Default: true
field-writes-are-uses: false
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
# Default: false
post-statements-are-reads: true
# Mark all exported fields as used.
# default: true
exported-fields-are-used: false
# Mark all function parameters as used.
# default: true
parameters-are-used: true
# Mark all local variables as used.
# default: true
local-variables-are-used: false
# Mark all identifiers inside generated files as used.
# Default: true
generated-is-used: false
formatters:
enable:
- gofmt
settings:
gofmt:
# Simplify code: gofmt with `-s` option.
# Default: true
simplify: false
# Apply the rewrite rules to the source before reformatting.
# https://pkg.go.dev/cmd/gofmt
# Default: []
rewrite-rules:
- pattern: 'interface{}'
replacement: 'any'
- pattern: 'a[b:len(a)]'
replacement: 'a[b:]'

16
backend/Makefile Normal file
View File

@@ -0,0 +1,16 @@
.PHONY: wire build build-embed
wire:
@echo "生成 Wire 代码..."
@cd cmd/server && go generate
@echo "Wire 代码生成完成"
build:
@echo "构建后端(不嵌入前端)..."
@go build -o bin/server ./cmd/server
@echo "构建完成: bin/server"
build-embed:
@echo "构建后端(嵌入前端)..."
@go build -tags embed -o bin/server ./cmd/server
@echo "构建完成: bin/server (with embedded frontend)"

View File

@@ -1,8 +1,11 @@
package main package main
//go:generate go run github.com/google/wire/cmd/wire
import ( import (
"context" "context"
_ "embed" _ "embed"
"errors"
"flag" "flag"
"log" "log"
"net/http" "net/http"
@@ -12,21 +15,13 @@ import (
"syscall" "syscall"
"time" "time"
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"sub2api/internal/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/setup"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/web"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/setup"
"sub2api/internal/web"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
) )
//go:embed VERSION //go:embed VERSION
@@ -100,8 +95,10 @@ func runSetupServer() {
r.Use(web.ServeEmbeddedFrontend()) r.Use(web.ServeEmbeddedFrontend())
} }
addr := ":8080" // Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
log.Printf("Setup wizard available at http://localhost%s", addr) // This allows users to run setup on a different address if needed
addr := config.GetServerAddress()
log.Printf("Setup wizard available at http://%s", addr)
log.Println("Complete the setup wizard to configure Sub2API") log.Println("Complete the setup wizard to configure Sub2API")
if err := r.Run(addr); err != nil { if err := r.Run(addr); err != nil {
@@ -110,78 +107,25 @@ func runSetupServer() {
} }
func runMainServer() { func runMainServer() {
// 加载配置
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 初始化时区(类似 PHP 的 date_default_timezone_set
if err := timezone.Init(cfg.Timezone); err != nil {
log.Fatalf("Failed to initialize timezone: %v", err)
}
// 初始化数据库
db, err := initDB(cfg)
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
// 初始化Redis
rdb := initRedis(cfg)
// 初始化Repository
repos := repository.NewRepositories(db)
// 初始化Service
services := service.NewServices(repos, rdb, cfg)
// 初始化Handler
buildInfo := handler.BuildInfo{ buildInfo := handler.BuildInfo{
Version: Version, Version: Version,
BuildType: BuildType, BuildType: BuildType,
} }
handlers := handler.NewHandlers(services, repos, rdb, buildInfo)
// 设置Gin模式 app, err := initializeApplication(buildInfo)
if cfg.Server.Mode == "release" { if err != nil {
gin.SetMode(gin.ReleaseMode) log.Fatalf("Failed to initialize application: %v", err)
}
// 创建路由
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
} }
defer app.Cleanup()
// 启动服务器 // 启动服务器
srv := &http.Server{
Addr: cfg.Server.Address(),
Handler: r,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
// 优雅关闭
go func() { go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := app.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start server: %v", err) log.Fatalf("Failed to start server: %v", err)
} }
}() }()
log.Printf("Server started on %s", cfg.Server.Address()) log.Printf("Server started on %s", app.Server.Addr)
// 等待中断信号 // 等待中断信号
quit := make(chan os.Signal, 1) quit := make(chan os.Signal, 1)
@@ -193,289 +137,9 @@ func runMainServer() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := srv.Shutdown(ctx); err != nil { if err := app.Server.Shutdown(ctx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err) log.Fatalf("Server forced to shutdown: %v", err)
} }
log.Println("Server exited") log.Println("Server exited")
} }
func initDB(cfg *config.Config) (*gorm.DB, error) {
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}
func initRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
// OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
}
}

125
backend/cmd/server/wire.go Normal file
View File

@@ -0,0 +1,125 @@
//go:build wireinject
// +build wireinject
package main
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/service"
"context"
"log"
"net/http"
"time"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
type Application struct {
Server *http.Server
Cleanup func()
}
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
wire.Build(
// 基础设施层 ProviderSets
config.ProviderSet,
infrastructure.ProviderSet,
// 业务层 ProviderSets
repository.ProviderSet,
service.ProviderSet,
handler.ProviderSet,
// 服务器层 ProviderSet
server.ProviderSet,
// BuildInfo provider
provideServiceBuildInfo,
// 清理函数提供者
provideCleanup,
// 应用程序结构体
wire.Struct(new(Application), "Server", "Cleanup"),
)
return nil, nil
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
BuildType: buildInfo.BuildType,
}
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Cleanup steps in reverse dependency order
cleanupSteps := []struct {
name string
fn func() error
}{
{"TokenRefreshService", func() error {
services.TokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
return nil
}},
{"OAuthService", func() error {
services.OAuth.Stop()
return nil
}},
{"OpenAIOAuthService", func() error {
services.OpenAIOAuth.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
{"Database", func() error {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}},
}
for _, step := range cleanupSteps {
if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err)
// Continue with remaining cleanup steps even if one fails
} else {
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
// Check if context timed out
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default:
log.Printf("[Cleanup] All cleanup steps completed")
}
}
}

View File

@@ -0,0 +1,249 @@
// Code generated by Wire. DO NOT EDIT.
//go:generate go run -mod=mod github.com/google/wire/cmd/wire
//go:build !wireinject
// +build !wireinject
package main
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
"log"
"net/http"
"time"
)
import (
_ "embed"
)
// Injectors from wire.go:
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
configConfig, err := config.ProvideConfig()
if err != nil {
return nil, err
}
db, err := infrastructure.ProvideDB(configConfig)
if err != nil {
return nil, err
}
userRepository := repository.NewUserRepository(db)
settingRepository := repository.NewSettingRepository(db)
settingService := service.NewSettingService(settingRepository, configConfig)
client := infrastructure.ProvideRedis(configConfig)
emailCache := repository.NewEmailCache(client)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
apiKeyCache := repository.NewApiKeyCache(client)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
billingCache := repository.NewBillingCache(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(client)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository)
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
proxyExitInfoProber := repository.NewProxyExitInfoProber()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
updateCache := repository.NewUpdateCache(client)
gitHubReleaseClient := repository.NewGitHubReleaseClient()
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
groupService := service.NewGroupService(groupRepository)
accountService := service.NewAccountService(accountRepository, groupRepository)
proxyService := service.NewProxyService(proxyRepository)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
services := &service.Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OpenAIGateway: openAIGatewayService,
OAuth: oAuthService,
OpenAIOAuth: openAIOAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
Update: updateService,
TokenRefresh: tokenRefreshService,
}
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services)
application := &Application{
Server: httpServer,
Cleanup: v,
}
return application, nil
}
// wire.go:
type Application struct {
Server *http.Server
Cleanup func()
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
BuildType: buildInfo.BuildType,
}
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cleanupSteps := []struct {
name string
fn func() error
}{
{"TokenRefreshService", func() error {
services.TokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
return nil
}},
{"OAuthService", func() error {
services.OAuth.Stop()
return nil
}},
{"OpenAIOAuthService", func() error {
services.OpenAIOAuth.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
{"Database", func() error {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}},
}
for _, step := range cleanupSteps {
if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err)
} else {
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default:
log.Printf("[Cleanup] All cleanup steps completed")
}
}
}

View File

@@ -1,38 +0,0 @@
server:
host: "0.0.0.0"
port: 8080
mode: "debug" # debug/release
database:
host: "127.0.0.1"
port: 5432
user: "postgres"
password: "XZeRr7nkjHWhm8fw"
dbname: "sub2api"
sslmode: "disable"
redis:
host: "127.0.0.1"
port: 6379
password: ""
db: 0
jwt:
secret: "your-secret-key-change-in-production"
expire_hour: 24
default:
admin_email: "admin@sub2api.com"
admin_password: "admin123"
user_concurrency: 5
user_balance: 0
api_key_prefix: "sk-"
rate_multiplier: 1.0
# Timezone configuration (similar to PHP's date_default_timezone_set)
# This affects ALL time operations:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
timezone: "Asia/Shanghai"

View File

@@ -1,4 +1,4 @@
module sub2api module github.com/Wei-Shaw/sub2api
go 1.24.0 go 1.24.0
@@ -8,11 +8,15 @@ require (
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0 github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/imroc/req/v3 v3.56.0 github.com/imroc/req/v3 v3.56.0
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.3.0 github.com/redis/go-redis/v9 v9.3.0
github.com/spf13/viper v1.18.2 github.com/spf13/viper v1.18.2
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/crypto v0.44.0 golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0 golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.4 gorm.io/driver/postgres v1.5.4
@@ -33,6 +37,7 @@ require (
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -50,6 +55,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect github.com/refraction-networking/utls v1.8.1 // indirect
@@ -60,15 +66,19 @@ require (
github.com/spf13/cast v1.6.0 // indirect github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/net v0.47.0 // indirect golang.org/x/mod v0.29.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
) )

View File

@@ -48,8 +48,12 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
@@ -135,6 +139,15 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
@@ -154,8 +167,12 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
@@ -166,6 +183,8 @@ golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=

View File

@@ -8,15 +8,30 @@ import (
) )
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"` Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"` Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"` JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"` Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"` RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"` Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"` Gateway GatewayConfig `mapstructure:"gateway"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
Enabled bool `mapstructure:"enabled"`
// 检查间隔(分钟)
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
// 提前刷新时间小时在token过期前多久开始刷新
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
// 最大重试次数
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
} }
type PricingConfig struct { type PricingConfig struct {
@@ -37,7 +52,7 @@ type PricingConfig struct {
type ServerConfig struct { type ServerConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
} }
@@ -148,7 +163,7 @@ func setDefaults() {
viper.SetDefault("server.port", 8080) viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug") viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
// Database // Database
viper.SetDefault("database.host", "localhost") viper.SetDefault("database.host", "localhost")
@@ -192,6 +207,13 @@ func setDefaults() {
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
@@ -203,3 +225,29 @@ func (c *Config) Validate() error {
} }
return nil return nil
} }
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
// Priority: config.yaml > environment variables > defaults
func GetServerAddress() string {
v := viper.New()
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/sub2api")
// Support SERVER_HOST and SERVER_PORT environment variables
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.port", 8080)
// Try to read config file (ignore errors if not found)
_ = v.ReadInConfig()
host := v.GetString("server.host")
port := v.GetInt("server.port")
return fmt.Sprintf("%s:%d", host, port)
}

View File

@@ -0,0 +1,13 @@
package config
import "github.com/google/wire"
// ProviderSet 提供配置层的依赖
var ProviderSet = wire.NewSet(
ProvideConfig,
)
// ProvideConfig 提供应用配置
func ProvideConfig() (*Config, error) {
return Load()
}

View File

@@ -3,8 +3,12 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -12,14 +16,12 @@ import (
// OAuthHandler handles OAuth-related operations for accounts // OAuthHandler handles OAuth-related operations for accounts
type OAuthHandler struct { type OAuthHandler struct {
oauthService *service.OAuthService oauthService *service.OAuthService
adminService service.AdminService
} }
// NewOAuthHandler creates a new OAuth handler // NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler { func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
return &OAuthHandler{ return &OAuthHandler{
oauthService: oauthService, oauthService: oauthService,
adminService: adminService,
} }
} }
@@ -27,47 +29,81 @@ func NewOAuthHandler(oauthService *service.OAuthService, adminService service.Ad
type AccountHandler struct { type AccountHandler struct {
adminService service.AdminService adminService service.AdminService
oauthService *service.OAuthService oauthService *service.OAuthService
openaiOAuthService *service.OpenAIOAuthService
rateLimitService *service.RateLimitService rateLimitService *service.RateLimitService
accountUsageService *service.AccountUsageService accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
} }
// NewAccountHandler creates a new admin account handler // NewAccountHandler creates a new admin account handler
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler { func NewAccountHandler(
adminService service.AdminService,
oauthService *service.OAuthService,
openaiOAuthService *service.OpenAIOAuthService,
rateLimitService *service.RateLimitService,
accountUsageService *service.AccountUsageService,
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
) *AccountHandler {
return &AccountHandler{ return &AccountHandler{
adminService: adminService, adminService: adminService,
oauthService: oauthService, oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
accountUsageService: accountUsageService, accountUsageService: accountUsageService,
accountTestService: accountTestService, accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
} }
} }
// CreateAccountRequest represents create account request // CreateAccountRequest represents create account request
type CreateAccountRequest struct { type CreateAccountRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"` Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"` Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials" binding:"required"` Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]interface{} `json:"extra"` Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
Priority int `json:"priority"` Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
} }
// UpdateAccountRequest represents update account request // UpdateAccountRequest represents update account request
// 使用指针类型来区分"未提供"和"设置为0" // 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"` Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials"` Credentials map[string]any `json:"credentials"`
Extra map[string]interface{} `json:"extra"` Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"` Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
}
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*model.Account
CurrentConcurrency int `json:"current_concurrency"`
} }
// List handles listing all accounts with pagination // List handles listing all accounts with pagination
@@ -85,7 +121,28 @@ func (h *AccountHandler) List(c *gin.Context) {
return return
} }
response.Paginated(c, accounts, total, page, pageSize) // Get current concurrency counts for all accounts
accountIDs := make([]int64, len(accounts))
for i, acc := range accounts {
accountIDs[i] = acc.ID
}
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
if err != nil {
// Log error but don't fail the request, just use 0 for all
concurrencyCounts = make(map[int64]int)
}
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
result[i] = AccountWithConcurrency{
Account: &accounts[i],
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
}
}
response.Paginated(c, result, total, page, pageSize)
} }
// GetByID handles getting an account by ID // GetByID handles getting an account by ID
@@ -186,6 +243,18 @@ func (h *AccountHandler) Delete(c *gin.Context) {
response.Success(c, gin.H{"message": "Account deleted successfully"}) response.Success(c, gin.H{"message": "Account deleted successfully"})
} }
// TestAccountRequest represents the request body for testing an account
type TestAccountRequest struct {
ModelID string `json:"model_id"`
}
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
}
// Test handles testing account connectivity with SSE streaming // Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test // POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) { func (h *AccountHandler) Test(c *gin.Context) {
@@ -195,13 +264,46 @@ func (h *AccountHandler) Test(c *gin.Context) {
return return
} }
var req TestAccountRequest
// Allow empty body, model_id is optional
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming // Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil { if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
// Error already sent via SSE, just log // Error already sent via SSE, just log
return return
} }
} }
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
// POST /api/v1/admin/accounts/sync/crs
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
var req SyncFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Default to syncing proxies (can be disabled by explicitly setting false)
syncProxies := true
if req.SyncProxies != nil {
syncProxies = *req.SyncProxies
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
})
if err != nil {
response.BadRequest(c, "Sync failed: "+err.Error())
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials // Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh // POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) { func (h *AccountHandler) Refresh(c *gin.Context) {
@@ -224,21 +326,46 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return return
} }
// Use OAuth service to refresh token var newCredentials map[string]any
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
// Update account credentials if account.IsOpenAI() {
newCredentials := map[string]interface{}{ // Use OpenAI OAuth service to refresh token
"access_token": tokenInfo.AccessToken, tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
"token_type": tokenInfo.TokenType, if err != nil {
"expires_in": tokenInfo.ExpiresIn, response.InternalError(c, "Failed to refresh credentials: "+err.Error())
"expires_at": tokenInfo.ExpiresAt, return
"refresh_token": tokenInfo.RefreshToken, }
"scope": tokenInfo.Scope,
// Build new credentials from token info
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
newCredentials = make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
newCredentials["refresh_token"] = tokenInfo.RefreshToken
newCredentials["scope"] = tokenInfo.Scope
} }
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
@@ -261,15 +388,26 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
return return
} }
// Return mock data for now // Parse days parameter (default 30)
_ = accountID days := 30
response.Success(c, gin.H{ if daysStr := c.Query("days"); daysStr != "" {
"total_requests": 0, if d, err := strconv.Atoi(daysStr); err == nil && d > 0 && d <= 90 {
"successful_requests": 0, days = d
"failed_requests": 0, }
"total_tokens": 0, }
"average_response_time": 0,
}) // Calculate time range
now := timezone.Now()
endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1))
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get account stats: "+err.Error())
return
}
response.Success(c, stats)
} }
// ClearError handles clearing account error // ClearError handles clearing account error
@@ -309,6 +447,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
}) })
} }
// BatchUpdateCredentialsRequest represents batch credentials update request
type BatchUpdateCredentialsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
Value any `json:"value"`
}
// BatchUpdateCredentials handles batch updating credentials fields
// POST /api/v1/admin/accounts/batch-update-credentials
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
var req BatchUpdateCredentialsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Validate value type based on field
if req.Field == "intercept_warmup_requests" {
// Must be boolean
if _, ok := req.Value.(bool); !ok {
response.BadRequest(c, "intercept_warmup_requests must be boolean")
return
}
} else {
// account_uuid and org_uuid can be string or null
if req.Value != nil {
if _, ok := req.Value.(string); !ok {
response.BadRequest(c, req.Field+" must be string or null")
return
}
}
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
})
}
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
// POST /api/v1/admin/accounts/bulk-update
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
var req BulkUpdateAccountsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.Status != "" ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
if !hasUpdates {
response.BadRequest(c, "No updates provided")
return
}
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
})
if err != nil {
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
return
}
response.Success(c, result)
}
// ========== OAuth Handlers ========== // ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL // GenerateAuthURLRequest represents the request for generating auth URL
@@ -535,3 +803,98 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
response.Success(c, account) response.Success(c, account)
} }
// GetAvailableModels handles getting available models for an account
// GET /api/v1/admin/accounts/:id/models
func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// Handle OpenAI accounts
if account.IsOpenAI() {
// For OAuth accounts: return default OpenAI models
if account.IsOAuth() {
response.Success(c, openai.DefaultModels)
return
}
// For API Key accounts: check model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, openai.DefaultModels)
return
}
// Return mapped models
var models []openai.Model
for requestedModel := range mapping {
var found bool
for _, dm := range openai.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, openai.Model{
ID: requestedModel,
Object: "model",
Type: "model",
DisplayName: requestedModel,
})
}
}
response.Success(c, models)
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
response.Success(c, claude.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
// No mapping configured, return default models
response.Success(c, claude.DefaultModels)
return
}
// Return mapped models (keys of the mapping are the available model IDs)
var models []claude.Model
for requestedModel := range mapping {
// Try to find display info from default models
var found bool
for _, dm := range claude.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
// If not found in defaults, create a basic entry
if !found {
models = append(models, claude.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
}

View File

@@ -1,11 +1,10 @@
package admin package admin
import ( import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"strconv" "strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -13,17 +12,15 @@ import (
// DashboardHandler handles admin dashboard statistics // DashboardHandler handles admin dashboard statistics
type DashboardHandler struct { type DashboardHandler struct {
adminService service.AdminService dashboardService *service.DashboardService
usageRepo *repository.UsageLogRepository startTime time.Time // Server start time for uptime calculation
startTime time.Time // Server start time for uptime calculation
} }
// NewDashboardHandler creates a new admin dashboard handler // NewDashboardHandler creates a new admin dashboard handler
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler { func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
return &DashboardHandler{ return &DashboardHandler{
adminService: adminService, dashboardService: dashboardService,
usageRepo: usageRepo, startTime: time.Now(),
startTime: time.Now(),
} }
} }
@@ -61,7 +58,7 @@ func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
// GetStats handles getting dashboard statistics // GetStats handles getting dashboard statistics
// GET /api/v1/admin/dashboard/stats // GET /api/v1/admin/dashboard/stats
func (h *DashboardHandler) GetStats(c *gin.Context) { func (h *DashboardHandler) GetStats(c *gin.Context) {
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context()) stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics") response.Error(c, 500, "Failed to get dashboard statistics")
return return
@@ -110,6 +107,10 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
// 系统运行统计 // 系统运行统计
"average_duration_ms": stats.AverageDurationMs, "average_duration_ms": stats.AverageDurationMs,
"uptime": uptime, "uptime": uptime,
// 性能指标
"rpm": stats.Rpm,
"tpm": stats.Tpm,
}) })
} }
@@ -127,12 +128,25 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data // GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend // GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour) // Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c) startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day") granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity) // Parse optional filter params
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get usage trend") response.Error(c, 500, "Failed to get usage trend")
return return
@@ -148,11 +162,24 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics // GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models // GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD) // Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
func (h *DashboardHandler) GetModelStats(c *gin.Context) { func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c) startTime, endTime := parseTimeRange(c)
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime) // Parse optional filter params
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get model statistics") response.Error(c, 500, "Failed to get model statistics")
return return
@@ -177,7 +204,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
limit = 5 limit = 5
} }
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get API key usage trend") response.Error(c, 500, "Failed to get API key usage trend")
return return
@@ -203,7 +230,7 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
limit = 12 limit = 12
} }
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get user usage trend") response.Error(c, 500, "Failed to get user usage trend")
return return
@@ -232,11 +259,11 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
} }
if len(req.UserIDs) == 0 { if len(req.UserIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}}) response.Success(c, gin.H{"stats": map[string]any{}})
return return
} }
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get user usage stats") response.Error(c, 500, "Failed to get user usage stats")
return return
@@ -260,11 +287,11 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
} }
if len(req.ApiKeyIDs) == 0 { if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}}) response.Success(c, gin.H{"stats": map[string]any{}})
return return
} }
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs) stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get API key usage stats") response.Error(c, 500, "Failed to get API key usage stats")
return return

View File

@@ -3,9 +3,9 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -0,0 +1,228 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
type OpenAIOAuthHandler struct {
openaiOAuthService *service.OpenAIOAuthService
adminService service.AdminService
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
openaiOAuthService: openaiOAuthService,
adminService: adminService,
}
}
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
type OpenAIGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
RedirectURI string `json:"redirect_uri"`
}
// GenerateAuthURL generates OpenAI OAuth authorization URL
// POST /api/v1/admin/openai/generate-auth-url
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req OpenAIGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body
req = OpenAIGenerateAuthURLRequest{}
}
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
return
}
response.Success(c, result)
}
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
// ExchangeCode exchanges OpenAI authorization code for tokens
// POST /api/v1/admin/openai/exchange-code
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
var req OpenAIExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
var proxyURL string
if req.ProxyID != nil {
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
if err != nil {
response.BadRequest(c, "Failed to refresh token: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI account
// POST /api/v1/admin/openai/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// Ensure account is OpenAI platform
if !account.IsOpenAI() {
response.BadRequest(c, "Account is not an OpenAI account")
return
}
// Only refresh OAuth-based accounts
if !account.IsOAuth() {
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
return
}
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
// Build new credentials from token info
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
return
}
response.Success(c, updatedAccount)
}
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Exchange code for tokens
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
name = "OpenAI OAuth Account"
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
Platform: "openai",
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to create account: "+err.Error())
return
}
response.Success(c, account)
}

View File

@@ -2,9 +2,10 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -112,12 +113,12 @@ func (h *ProxyHandler) Create(c *gin.Context) {
} }
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: req.Name, Name: strings.TrimSpace(req.Name),
Protocol: req.Protocol, Protocol: strings.TrimSpace(req.Protocol),
Host: req.Host, Host: strings.TrimSpace(req.Host),
Port: req.Port, Port: req.Port,
Username: req.Username, Username: strings.TrimSpace(req.Username),
Password: req.Password, Password: strings.TrimSpace(req.Password),
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error()) response.BadRequest(c, "Failed to create proxy: "+err.Error())
@@ -143,13 +144,13 @@ func (h *ProxyHandler) Update(c *gin.Context) {
} }
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{ proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: req.Name, Name: strings.TrimSpace(req.Name),
Protocol: req.Protocol, Protocol: strings.TrimSpace(req.Protocol),
Host: req.Host, Host: strings.TrimSpace(req.Host),
Port: req.Port, Port: req.Port,
Username: req.Username, Username: strings.TrimSpace(req.Username),
Password: req.Password, Password: strings.TrimSpace(req.Password),
Status: req.Status, Status: strings.TrimSpace(req.Status),
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error()) response.InternalError(c, "Failed to update proxy: "+err.Error())
@@ -235,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
response.Paginated(c, accounts, total, page, pageSize) response.Paginated(c, accounts, total, page, pageSize)
} }
// BatchCreateProxyItem represents a single proxy in batch create request // BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct { type BatchCreateProxyItem struct {
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
@@ -263,8 +263,14 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
skipped := 0 skipped := 0
for _, item := range req.Proxies { for _, item := range req.Proxies {
// Trim all string fields
host := strings.TrimSpace(item.Host)
protocol := strings.TrimSpace(item.Protocol)
username := strings.TrimSpace(item.Username)
password := strings.TrimSpace(item.Password)
// Check for duplicates (same host, port, username, password) // Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password) exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil { if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error()) response.InternalError(c, "Failed to check proxy existence: "+err.Error())
return return
@@ -278,11 +284,11 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Create proxy with default name // Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ _, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default", Name: "default",
Protocol: item.Protocol, Protocol: protocol,
Host: item.Host, Host: host,
Port: item.Port, Port: item.Port,
Username: item.Username, Username: username,
Password: item.Password, Password: password,
}) })
if err != nil { if err != nil {
// If creation fails due to duplicate, count as skipped // If creation fails due to duplicate, count as skipped

View File

@@ -6,8 +6,8 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
func (h *RedeemHandler) GetStats(c *gin.Context) { func (h *RedeemHandler) GetStats(c *gin.Context) {
// Return mock data for now // Return mock data for now
response.Success(c, gin.H{ response.Success(c, gin.H{
"total_codes": 0, "total_codes": 0,
"active_codes": 0, "active_codes": 0,
"used_codes": 0, "used_codes": 0,
"expired_codes": 0, "expired_codes": 0,
"total_value_distributed": 0.0, "total_value_distributed": 0.0,
"by_type": gin.H{ "by_type": gin.H{
"balance": 0, "balance": 0,
@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf) writer := csv.NewWriter(&buf)
// Write header // Write header
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}) if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
// Write data rows // Write data rows
for _, code := range codes { for _, code := range codes {
@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedAt != nil { if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05") usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
} }
writer.Write([]string{ if err := writer.Write([]string{
fmt.Sprintf("%d", code.ID), fmt.Sprintf("%d", code.ID),
code.Code, code.Code,
code.Type, code.Type,
@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
usedBy, usedBy,
usedAt, usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"), code.CreatedAt.Format("2006-01-02 15:04:05"),
}) }); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
} }
writer.Flush() writer.Flush()
if err := writer.Error(); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
c.Header("Content-Type", "text/csv") c.Header("Content-Type", "text/csv")
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv") c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")

View File

@@ -1,9 +1,9 @@
package admin package admin
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -60,6 +60,7 @@ type UpdateSettingsRequest struct {
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"` ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
@@ -104,6 +105,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
SiteSubtitle: req.SiteSubtitle, SiteSubtitle: req.SiteSubtitle,
ApiBaseUrl: req.ApiBaseUrl, ApiBaseUrl: req.ApiBaseUrl,
ContactInfo: req.ContactInfo, ContactInfo: req.ContactInfo,
DocUrl: req.DocUrl,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
} }
@@ -256,3 +258,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
response.Success(c, gin.H{"message": "Test email sent successfully"}) response.Success(c, gin.H{"message": "Test email sent successfully"})
} }
// GetAdminApiKey 获取管理员 API Key 状态
// GET /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
return
}
response.Success(c, gin.H{
"exists": exists,
"masked_key": maskedKey,
})
}
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
// POST /api/v1/admin/settings/admin-api-key/regenerate
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
return
}
response.Success(c, gin.H{
"key": key, // 完整 key 只在生成时返回一次
})
}
// DeleteAdminApiKey 删除管理员 API Key
// DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Admin API key deleted"})
}

View File

@@ -3,16 +3,16 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// toResponsePagination converts repository.PaginationResult to response.PaginationResult // toResponsePagination converts pagination.PaginationResult to response.PaginationResult
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult { func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
if p == nil { if p == nil {
return nil return nil
} }

View File

@@ -4,12 +4,11 @@ import (
"net/http" "net/http"
"time" "time"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/pkg/sysutil" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
// SystemHandler handles system-related operations // SystemHandler handles system-related operations
@@ -18,9 +17,9 @@ type SystemHandler struct {
} }
// NewSystemHandler creates a new SystemHandler // NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler { func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
return &SystemHandler{ return &SystemHandler{
updateSvc: service.NewUpdateService(rdb, version, buildType), updateSvc: updateSvc,
} }
} }

View File

@@ -4,34 +4,32 @@ import (
"strconv" "strconv"
"time" "time"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// UsageHandler handles admin usage-related requests // UsageHandler handles admin usage-related requests
type UsageHandler struct { type UsageHandler struct {
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService usageService *service.UsageService
apiKeyService *service.ApiKeyService
adminService service.AdminService adminService service.AdminService
} }
// NewUsageHandler creates a new admin usage handler // NewUsageHandler creates a new admin usage handler
func NewUsageHandler( func NewUsageHandler(
usageRepo *repository.UsageLogRepository,
apiKeyRepo *repository.ApiKeyRepository,
usageService *service.UsageService, usageService *service.UsageService,
apiKeyService *service.ApiKeyService,
adminService service.AdminService, adminService service.AdminService,
) *UsageHandler { ) *UsageHandler {
return &UsageHandler{ return &UsageHandler{
usageRepo: usageRepo, usageService: usageService,
apiKeyRepo: apiKeyRepo, apiKeyService: apiKeyService,
usageService: usageService, adminService: adminService,
adminService: adminService,
} }
} }
@@ -82,15 +80,15 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t endTime = &t
} }
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{ filters := usagestats.UsageLogFilters{
UserID: userID, UserID: userID,
ApiKeyID: apiKeyID, ApiKeyID: apiKeyID,
StartTime: startTime, StartTime: startTime,
EndTime: endTime, EndTime: endTime,
} }
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters) records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error()) response.InternalError(c, "Failed to list usage records: "+err.Error())
return return
@@ -178,7 +176,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
} }
// Get global stats // Get global stats
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime) stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return return
@@ -192,7 +190,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
func (h *UsageHandler) SearchUsers(c *gin.Context) { func (h *UsageHandler) SearchUsers(c *gin.Context) {
keyword := c.Query("q") keyword := c.Query("q")
if keyword == "" { if keyword == "" {
response.Success(c, []interface{}{}) response.Success(c, []any{})
return return
} }
@@ -236,7 +234,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userID = id userID = id
} }
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30) keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil { if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error()) response.InternalError(c, "Failed to search API keys: "+err.Error())
return return

View File

@@ -3,8 +3,8 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -25,6 +25,9 @@ func NewUserHandler(adminService service.AdminService) *UserHandler {
type CreateUserRequest struct { type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"` Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Balance float64 `json:"balance"` Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"` AllowedGroups []int64 `json:"allowed_groups"`
@@ -35,6 +38,9 @@ type CreateUserRequest struct {
type UpdateUserRequest struct { type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"` Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"` Password string `json:"password" binding:"omitempty,min=6"`
Username *string `json:"username"`
Wechat *string `json:"wechat"`
Notes *string `json:"notes"`
Balance *float64 `json:"balance"` Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
@@ -43,8 +49,9 @@ type UpdateUserRequest struct {
// UpdateBalanceRequest represents balance update request // UpdateBalanceRequest represents balance update request
type UpdateBalanceRequest struct { type UpdateBalanceRequest struct {
Balance float64 `json:"balance" binding:"required"` Balance float64 `json:"balance" binding:"required,gt=0"`
Operation string `json:"operation" binding:"required,oneof=set add subtract"` Operation string `json:"operation" binding:"required,oneof=set add subtract"`
Notes string `json:"notes"`
} }
// List handles listing all users with pagination // List handles listing all users with pagination
@@ -94,6 +101,9 @@ func (h *UserHandler) Create(c *gin.Context) {
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{ user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email, Email: req.Email,
Password: req.Password, Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance, Balance: req.Balance,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
@@ -125,6 +135,9 @@ func (h *UserHandler) Update(c *gin.Context) {
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{ user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email, Email: req.Email,
Password: req.Password, Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance, Balance: req.Balance,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Status: req.Status, Status: req.Status,
@@ -171,7 +184,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return return
} }
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation) user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil { if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error()) response.InternalError(c, "Failed to update balance: "+err.Error())
return return

View File

@@ -3,10 +3,10 @@ package handler
import ( import (
"strconv" "strconv"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil { if err != nil {

View File

@@ -1,9 +1,9 @@
package handler package handler
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -7,28 +7,24 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"strings"
"time" "time"
"sub2api/internal/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
const (
// Maximum wait time for concurrency slot
maxConcurrencyWait = 60 * time.Second
// Ping interval during wait
pingInterval = 5 * time.Second
)
// GatewayHandler handles API gateway requests // GatewayHandler handles API gateway requests
type GatewayHandler struct { type GatewayHandler struct {
gatewayService *service.GatewayService gatewayService *service.GatewayService
userService *service.UserService userService *service.UserService
concurrencyService *service.ConcurrencyService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
} }
// NewGatewayHandler creates a new GatewayHandler // NewGatewayHandler creates a new GatewayHandler
@@ -36,8 +32,8 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
return &GatewayHandler{ return &GatewayHandler{
gatewayService: gatewayService, gatewayService: gatewayService,
userService: userService, userService: userService,
concurrencyService: concurrencyService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
} }
} }
@@ -87,7 +83,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 0. 检查wait队列是否已满 // 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed // On error, allow request to proceed
@@ -96,10 +92,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 确保在函数退出时减少wait计数 // 确保在函数退出时减少wait计数
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID) defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
// 1. 首先获取用户并发槽位 // 1. 首先获取用户并发槽位
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -126,8 +122,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -161,154 +167,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}() }()
} }
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
}
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
}
// concurrencyError represents a concurrency limit error with context
type concurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *concurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
// Note: For streaming requests, we send ping to keep the connection alive.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// For streaming requests, set up SSE headers for ping
var flusher http.Flusher
if isStream {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &concurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingTicker.C:
// Send ping for streaming requests to keep connection alive
if isStream && flusher != nil {
// Set headers on first ping (lazy initialization)
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
flusher.Flush()
}
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}
// Models handles listing available models // Models handles listing available models
// GET /v1/models // GET /v1/models
// Returns different model lists based on the API key's group platform
func (h *GatewayHandler) Models(c *gin.Context) { func (h *GatewayHandler) Models(c *gin.Context) {
models := []gin.H{ apiKey, _ := middleware.GetApiKeyFromContext(c)
{
"id": "claude-opus-4-5-20251101", // Return OpenAI models for OpenAI platform groups
"type": "model", if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
"display_name": "Claude Opus 4.5", c.JSON(http.StatusOK, gin.H{
"created_at": "2025-11-01T00:00:00Z", "object": "list",
}, "data": openai.DefaultModels,
{ })
"id": "claude-sonnet-4-5-20250929", return
"type": "model",
"display_name": "Claude Sonnet 4.5",
"created_at": "2025-09-29T00:00:00Z",
},
{
"id": "claude-haiku-4-5-20251001",
"type": "model",
"display_name": "Claude Haiku 4.5",
"created_at": "2025-10-01T00:00:00Z",
},
} }
// Default: Claude models
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": models,
"object": "list", "object": "list",
"data": claude.DefaultModels,
}) })
} }
@@ -423,7 +300,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
if ok { if ok {
// Send error event in SSE format // Send error event in SSE format
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
fmt.Fprint(c.Writer, errorEvent) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush() flusher.Flush()
} }
return return
@@ -443,3 +322,155 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
}, },
}) })
} }
// CountTokens handles token counting endpoint
// POST /v1/messages/count_tokens
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 获取订阅信息可能为nil
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
}
// 解析完整请求
var req struct {
Messages []struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
} `json:"messages"`
System []struct {
Text string `json:"text"`
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
}
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
events := []string{
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
time.Sleep(20 * time.Millisecond)
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
},
})
}

View File

@@ -0,0 +1,180 @@
package handler
import (
"context"
"fmt"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
maxConcurrencyWait = 30 * time.Second
// pingInterval is the interval for sending ping events during slot wait
pingInterval = 15 * time.Second
)
// SSEPingFormat defines the format of SSE ping events for different platforms
type SSEPingFormat string
const (
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
)
// ConcurrencyError represents a concurrency limit error with context
type ConcurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *ConcurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
}
}
// IncrementWaitCount increments the wait count for a user
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
}
// DecrementWaitCount decrements the wait count for a user
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
}
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// Determine if ping is needed (streaming + ping format defined)
needPing := isStream && h.pingFormat != ""
var flusher http.Flusher
if needPing {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &ConcurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingCh:
// Send ping to keep connection alive
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
return nil, err
}
flusher.Flush()
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}

View File

@@ -1,11 +1,7 @@
package handler package handler
import ( import (
"sub2api/internal/handler/admin" "github.com/Wei-Shaw/sub2api/internal/handler/admin"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/redis/go-redis/v9"
) )
// AdminHandlers contains all admin-related HTTP handlers // AdminHandlers contains all admin-related HTTP handlers
@@ -15,6 +11,7 @@ type AdminHandlers struct {
Group *admin.GroupHandler Group *admin.GroupHandler
Account *admin.AccountHandler Account *admin.AccountHandler
OAuth *admin.OAuthHandler OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
Proxy *admin.ProxyHandler Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler Redeem *admin.RedeemHandler
Setting *admin.SettingHandler Setting *admin.SettingHandler
@@ -25,15 +22,16 @@ type AdminHandlers struct {
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers
type Handlers struct { type Handlers struct {
Auth *AuthHandler Auth *AuthHandler
User *UserHandler User *UserHandler
APIKey *APIKeyHandler APIKey *APIKeyHandler
Usage *UsageHandler Usage *UsageHandler
Redeem *RedeemHandler Redeem *RedeemHandler
Subscription *SubscriptionHandler Subscription *SubscriptionHandler
Admin *AdminHandlers Admin *AdminHandlers
Gateway *GatewayHandler Gateway *GatewayHandler
Setting *SettingHandler OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
} }
// BuildInfo contains build-time information // BuildInfo contains build-time information
@@ -41,30 +39,3 @@ type BuildInfo struct {
Version string Version string
BuildType string // "source" for manual builds, "release" for CI builds BuildType string // "source" for manual builds, "release" for CI builds
} }
// NewHandlers creates a new Handlers instance with all handlers initialized
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
return &Handlers{
Auth: NewAuthHandler(services.Auth),
User: NewUserHandler(services.User),
APIKey: NewAPIKeyHandler(services.ApiKey),
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
Redeem: NewRedeemHandler(services.Redeem),
Subscription: NewSubscriptionHandler(services.Subscription),
Admin: &AdminHandlers{
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
User: admin.NewUserHandler(services.Admin),
Group: admin.NewGroupHandler(services.Admin),
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
Proxy: admin.NewProxyHandler(services.Admin),
Redeem: admin.NewRedeemHandler(services.Admin),
Setting: admin.NewSettingHandler(services.Setting, services.Email),
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
Subscription: admin.NewSubscriptionHandler(services.Subscription),
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
},
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
}
}

View File

@@ -0,0 +1,209 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
) *OpenAIGatewayHandler {
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
}
}
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// Read request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// Parse request body to map for potential modification
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
reqBody["instructions"] = openai.DefaultInstructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
// Track if we've started streaming (for error handling)
streamStarted := false
// Get subscription info (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
// Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
}
// Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c)
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}
// Async record usage
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in OpenAI SSE format
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
// Normal case: return JSON response with proper status code
h.errorResponse(c, status, errType, message)
}
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -1,9 +1,9 @@
package handler package handler
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,8 +1,8 @@
package handler package handler
import ( import (
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,9 +1,9 @@
package handler package handler
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -4,11 +4,11 @@ import (
"strconv" "strconv"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -16,15 +16,13 @@ import (
// UsageHandler handles usage-related requests // UsageHandler handles usage-related requests
type UsageHandler struct { type UsageHandler struct {
usageService *service.UsageService usageService *service.UsageService
usageRepo *repository.UsageLogRepository
apiKeyService *service.ApiKeyService apiKeyService *service.ApiKeyService
} }
// NewUsageHandler creates a new UsageHandler // NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler { func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
return &UsageHandler{ return &UsageHandler{
usageService: usageService, usageService: usageService,
usageRepo: usageRepo,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
} }
} }
@@ -68,9 +66,9 @@ func (h *UsageHandler) List(c *gin.Context) {
apiKeyID = id apiKeyID = id
} }
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog var records []model.UsageLog
var result *repository.PaginationResult var result *pagination.PaginationResult
var err error var err error
if apiKeyID > 0 { if apiKeyID > 0 {
@@ -259,7 +257,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
return return
} }
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID) stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get dashboard statistics") response.InternalError(c, "Failed to get dashboard statistics")
return return
@@ -286,7 +284,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day") granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage trend") response.InternalError(c, "Failed to get usage trend")
return return
@@ -317,7 +315,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get model statistics") response.InternalError(c, "Failed to get model statistics")
return return
@@ -357,12 +355,12 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
} }
if len(req.ApiKeyIDs) == 0 { if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}}) response.Success(c, gin.H{"stats": map[string]any{}})
return return
} }
// Verify ownership of all requested API keys // Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000}) userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil { if err != nil {
response.InternalError(c, "Failed to verify API key ownership") response.InternalError(c, "Failed to verify API key ownership")
return return
@@ -382,11 +380,11 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
} }
if len(validApiKeyIDs) == 0 { if len(validApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}}) response.Success(c, gin.H{"stats": map[string]any{}})
return return
} }
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs) stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get API key usage stats") response.InternalError(c, "Failed to get API key usage stats")
return return

View File

@@ -1,9 +1,9 @@
package handler package handler
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -26,6 +26,12 @@ type ChangePasswordRequest struct {
NewPassword string `json:"new_password" binding:"required,min=6"` NewPassword string `json:"new_password" binding:"required,min=6"`
} }
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
Wechat *string `json:"wechat"`
}
// GetProfile handles getting user profile // GetProfile handles getting user profile
// GET /api/v1/users/me // GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) { func (h *UserHandler) GetProfile(c *gin.Context) {
@@ -47,6 +53,9 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return return
} }
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, userData) response.Success(c, userData)
} }
@@ -83,3 +92,40 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
response.Success(c, gin.H{"message": "Password changed successfully"}) response.Success(c, gin.H{"message": "Password changed successfully"})
} }
// UpdateProfile handles updating user profile
// PUT /api/v1/users/me
func (h *UserHandler) UpdateProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req UpdateProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.UpdateProfileRequest{
Username: req.Username,
Wechat: req.Wechat,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to update profile: "+err.Error())
return
}
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, updatedUser)
}

View File

@@ -0,0 +1,108 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
)
// ProvideAdminHandlers creates the AdminHandlers struct
func ProvideAdminHandlers(
dashboardHandler *admin.DashboardHandler,
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler,
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
User: userHandler,
Group: groupHandler,
Account: accountHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Setting: settingHandler,
System: systemHandler,
Subscription: subscriptionHandler,
Usage: usageHandler,
}
}
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
return admin.NewSystemHandler(updateService)
}
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
return NewSettingHandler(settingService, buildInfo.Version)
}
// ProvideHandlers creates the Handlers struct
func ProvideHandlers(
authHandler *AuthHandler,
userHandler *UserHandler,
apiKeyHandler *APIKeyHandler,
usageHandler *UsageHandler,
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
User: userHandler,
APIKey: apiKeyHandler,
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler,
}
}
// ProviderSet is the Wire provider set for all handlers
var ProviderSet = wire.NewSet(
// Top-level handlers
NewAuthHandler,
NewUserHandler,
NewAPIKeyHandler,
NewUsageHandler,
NewRedeemHandler,
NewSubscriptionHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
ProvideSettingHandler,
// Admin handlers
admin.NewDashboardHandler,
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewSettingHandler,
ProvideSystemHandler,
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
ProvideHandlers,
)

View File

@@ -0,0 +1,38 @@
package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// InitDB 初始化数据库连接
func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 初始化时区(在数据库连接之前,确保时区设置正确)
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, err
}
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}

View File

@@ -0,0 +1,16 @@
package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}

View File

@@ -0,0 +1,25 @@
package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// ProviderSet 提供基础设施层的依赖
var ProviderSet = wire.NewSet(
ProvideDB,
ProvideRedis,
)
// ProvideDB 提供数据库连接
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
return InitDB(cfg)
}
// ProvideRedis 提供 Redis 客户端
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -0,0 +1,130 @@
package middleware
import (
"context"
"crypto/subtle"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin"
)
// AdminAuth 管理员认证中间件
// 支持两种认证方式(通过不同的 header 区分):
// 1. Admin API Key: x-api-key: <admin-api-key>
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
func AdminAuth(
authService *service.AuthService,
userRepo interface {
GetByID(ctx context.Context, id int64) (*model.User, error)
GetFirstAdmin(ctx context.Context) (*model.User, error)
},
settingService *service.SettingService,
) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userRepo) {
return
}
c.Next()
return
}
// 检查 Authorization headerJWT 认证)
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userRepo) {
return
}
c.Next()
return
}
}
// 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
}
}
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userRepo interface {
GetFirstAdmin(ctx context.Context) (*model.User, error)
},
) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false
}
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
return false
}
// 获取真实的管理员用户
admin, err := userRepo.GetFirstAdmin(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
return false
}
c.Set(string(ContextKeyUser), admin)
c.Set("auth_method", "admin_api_key")
return true
}
// validateJWTForAdmin 验证 JWT 并检查管理员权限
func validateJWTForAdmin(
c *gin.Context,
token string,
authService *service.AuthService,
userRepo interface {
GetByID(ctx context.Context, id int64) (*model.User, error)
},
) bool {
// 验证 JWT token
claims, err := authService.ValidateToken(token)
if err != nil {
if err == service.ErrTokenExpired {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return false
}
// 从数据库获取用户
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return false
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return false
}
// 检查管理员权限
if user.Role != model.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false
}
c.Set(string(ContextKeyUser), user)
c.Set("auth_method", "jwt")
return true
}

View File

@@ -1,7 +1,7 @@
package middleware package middleware
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -3,9 +3,9 @@ package middleware
import ( import (
"context" "context"
"errors" "errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"log" "log"
"strings" "strings"
"sub2api/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"

View File

@@ -2,9 +2,9 @@ package middleware
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings" "strings"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -10,7 +10,7 @@ import (
) )
// JSONB 用于存储JSONB数据 // JSONB 用于存储JSONB数据
type JSONB map[string]interface{} type JSONB map[string]any
func (j JSONB) Value() (driver.Value, error) { func (j JSONB) Value() (driver.Value, error) {
if j == nil { if j == nil {
@@ -19,7 +19,7 @@ func (j JSONB) Value() (driver.Value, error) {
return json.Marshal(j) return json.Marshal(j)
} }
func (j *JSONB) Scan(value interface{}) error { func (j *JSONB) Scan(value any) error {
if value == nil { if value == nil {
*j = nil *j = nil
return nil return nil
@@ -40,8 +40,8 @@ type Account struct {
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息 Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"` ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"` Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高 Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"` ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"` LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"` CreatedAt time.Time `gorm:"not null" json:"created_at"`
@@ -68,7 +68,8 @@ type Account struct {
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"` AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库) // 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"` GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
Groups []*Group `gorm:"-" json:"groups,omitempty"`
} }
func (Account) TableName() string { func (Account) TableName() string {
@@ -145,7 +146,7 @@ func (a *Account) GetModelMapping() map[string]string {
return nil return nil
} }
// 处理map[string]interface{}类型 // 处理map[string]interface{}类型
if m, ok := raw.(map[string]interface{}); ok { if m, ok := raw.(map[string]any); ok {
result := make(map[string]string) result := make(map[string]string)
for k, v := range m { for k, v := range m {
if s, ok := v.(string); ok { if s, ok := v.(string); ok {
@@ -163,7 +164,7 @@ func (a *Account) GetModelMapping() map[string]string {
// 如果没有设置模型映射,则支持所有模型 // 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool { func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 { if len(mapping) == 0 {
return true // 没有映射配置,支持所有模型 return true // 没有映射配置,支持所有模型
} }
_, exists := mapping[requestedModel] _, exists := mapping[requestedModel]
@@ -174,7 +175,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
// 如果没有映射,返回原始模型名 // 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 { if len(mapping) == 0 {
return requestedModel return requestedModel
} }
if mappedModel, exists := mapping[requestedModel]; exists { if mappedModel, exists := mapping[requestedModel]; exists {
@@ -231,7 +232,7 @@ func (a *Account) GetCustomErrorCodes() []int {
return nil return nil
} }
// 处理 []interface{} 类型JSON反序列化后的格式 // 处理 []interface{} 类型JSON反序列化后的格式
if arr, ok := raw.([]interface{}); ok { if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr)) result := make([]int, 0, len(arr))
for _, v := range arr { for _, v := range arr {
// JSON 数字默认解析为 float64 // JSON 数字默认解析为 float64
@@ -263,3 +264,152 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
} }
return false return false
} }
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后标题生成、Warmup等预热请求将返回mock响应不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// =============== OpenAI 相关方法 ===============
// IsOpenAI 检查是否为 OpenAI 平台账号
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}
// IsAnthropic 检查是否为 Anthropic 平台账号
func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic
}
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号Response 账号)
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
// 对于 API Key 类型账号,从 credentials 中获取 base_url
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
}
}
return "https://api.openai.com" // OpenAI 默认 API URL
}
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("access_token")
}
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("refresh_token")
}
// GetOpenAIIDToken 获取 OpenAI ID TokenJWT包含用户信息
func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("id_token")
}
// GetOpenAIApiKey 获取 OpenAI API Key用于 Response 账号)
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")
}
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
// 返回空字符串表示透传原始 User-Agent
func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("user_agent")
}
// GetChatGPTAccountID 获取 ChatGPT 账号 ID从 ID Token 解析)
func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_account_id")
}
// GetChatGPTUserID 获取 ChatGPT 用户 ID从 ID Token 解析)
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_user_id")
}
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("organization_id")
}
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
// 尝试解析时间
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
// 尝试解析为 Unix 时间戳
if v, ok := a.Credentials["expires_at"].(float64); ok {
t = time.Unix(int64(v), 0)
return &t
}
return nil
}
return &t
}
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false // 没有过期时间信息,假设未过期
}
// 提前 60 秒认为过期,便于刷新
return time.Now().Add(60 * time.Second).After(*expiresAt)
}

View File

@@ -13,13 +13,13 @@ const (
) )
type Group struct { type Group struct {
ID int64 `gorm:"primaryKey" json:"id"` ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"` Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"` Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"` RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"` IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段 // 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription

View File

@@ -9,15 +9,16 @@ import (
type RedeemCode struct { type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"` ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"` Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数 Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"` UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"` UsedAt *time.Time `json:"used_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"` CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段 // 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用) GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用) ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联 // 关联
@@ -40,8 +41,10 @@ func (r *RedeemCode) CanUse() bool {
} }
// GenerateRedeemCode 生成唯一的兑换码 // GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() string { func GenerateRedeemCode() (string, error) {
b := make([]byte, 16) b := make([]byte, 16)
rand.Read(b) if _, err := rand.Read(b); err != nil {
return hex.EncodeToString(b) return "", err
}
return hex.EncodeToString(b), nil
} }

View File

@@ -19,17 +19,17 @@ func (Setting) TableName() string {
// 设置Key常量 // 设置Key常量
const ( const (
// 注册设置 // 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置 // 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址 SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口 SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名 SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储 SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址 SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称 SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置 // Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证 SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
@@ -42,12 +42,19 @@ const (
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入 SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式 SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
) )
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
const AdminApiKeyPrefix = "admin-"
// SystemSettings 系统设置结构体用于API响应 // SystemSettings 系统设置结构体用于API响应
type SystemSettings struct { type SystemSettings struct {
// 注册设置 // 注册设置
@@ -74,6 +81,7 @@ type SystemSettings struct {
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"` ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
@@ -91,5 +99,6 @@ type PublicSettings struct {
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"` ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"` Version string `json:"version"`
} }

View File

@@ -37,7 +37,7 @@ type UsageLog struct {
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"` OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"` CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"` CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用 TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用 ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率 RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率

View File

@@ -9,8 +9,11 @@ import (
) )
type User struct { type User struct {
ID int64 `gorm:"primaryKey" json:"id"` ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"` Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
Username string `gorm:"size:100;default:''" json:"username"`
Wechat string `gorm:"size:100;default:''" json:"wechat"`
Notes string `gorm:"type:text;default:''" json:"notes"`
PasswordHash string `gorm:"size:255;not null" json:"-"` PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"` Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
@@ -22,7 +25,8 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联 // 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"` ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
} }
func (User) TableName() string { func (User) TableName() string {

View File

@@ -0,0 +1,74 @@
package claude
// Claude Code 客户端相关常量
// Beta header 常量
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
// Model 表示一个 Claude 模型
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels Claude Code 客户端支持的默认模型列表
var DefaultModels = []Model{
{
ID: "claude-opus-4-5-20251101",
Type: "model",
DisplayName: "Claude Opus 4.5",
CreatedAt: "2025-11-01T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
DisplayName: "Claude Sonnet 4.5",
CreatedAt: "2025-09-29T00:00:00Z",
},
{
ID: "claude-haiku-4-5-20251001",
Type: "model",
DisplayName: "Claude Haiku 4.5",
CreatedAt: "2025-10-01T00:00:00Z",
},
}
// DefaultModelIDs 返回默认模型的 ID 列表
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"

View File

@@ -43,18 +43,25 @@ type OAuthSession struct {
type SessionStore struct { type SessionStore struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]*OAuthSession sessions map[string]*OAuthSession
stopCh chan struct{}
} }
// NewSessionStore creates a new session store // NewSessionStore creates a new session store
func NewSessionStore() *SessionStore { func NewSessionStore() *SessionStore {
store := &SessionStore{ store := &SessionStore{
sessions: make(map[string]*OAuthSession), sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
} }
// Start cleanup goroutine // Start cleanup goroutine
go store.cleanup() go store.cleanup()
return store return store
} }
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// Set stores a session // Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) { func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock() s.mu.Lock()
@@ -87,14 +94,20 @@ func (s *SessionStore) Delete(sessionID string) {
// cleanup removes expired sessions periodically // cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() { func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
for range ticker.C { defer ticker.Stop()
s.mu.Lock() for {
for id, session := range s.sessions { select {
if time.Since(session.CreatedAt) > SessionTTL { case <-s.stopCh:
delete(s.sessions, id) return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
} }
s.mu.Unlock()
} }
s.mu.Unlock()
} }
} }

View File

@@ -0,0 +1,42 @@
package openai
import _ "embed"
// Model represents an OpenAI model
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
}
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
}
// DefaultModelIDs returns the default model ID list
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel default model for testing OpenAI accounts
const DefaultTestModel = "gpt-5.1-codex"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
//
//go:embed instructions.txt
var DefaultInstructions string

View File

@@ -0,0 +1,118 @@
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
## General
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
## Editing constraints
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
- You may be in a dirty git worktree.
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
* If the changes are in unrelated files, just ignore them and don't revert them.
- Do not amend a commit unless explicitly requested to do so.
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
## Plan tool
When using the planning tool:
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
- Do not make single-step plans.
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
## Codex CLI harness, sandboxing, and approvals
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
- **read-only**: The sandbox only permits reading files.
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
- **restricted**: Requires approval
- **enabled**: No approval needed
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
- (for all of these, you should weigh alternative paths that do not require approval)
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
When requesting approval to execute a command that will require escalated privileges:
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
## Special user requests
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
## Frontend tasks
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
Aim for interfaces that feel intentional, bold, and a bit surprising.
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
- Ensure the page loads properly on both desktop and mobile
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
## Presenting your work and final message
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
- Default: be very concise; friendly coding teammate tone.
- Ask only when needed; suggest ideas; mirror the user's style.
- For substantial work, summarize clearly; follow finalanswer formatting.
- Skip heavy formatting for simple confirmations.
- Don't dump large files you've written; reference paths only.
- No \"save/copy this file\" - User is on the same machine.
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
- For code changes:
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
### Final answer structure and style guidelines
- Plain text; CLI handles styling. Use structure only when it helps scanability.
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
- Bullets: use - ; merge related points; keep to one line when possible; 46 per list ordered by importance; keep phrasing consistent.
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
- Tone: collaborative, concise, factual; present tense, active voice; selfcontained; no \"above/below\"; parallel wording.
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
- File References: When referencing files in your response follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Optionally include line/column (1based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5

View File

@@ -0,0 +1,366 @@
package openai
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
TokenURL = "https://auth.openai.com/oauth/token"
// Default redirect URI (can be customized)
DefaultRedirectURI = "http://localhost:1455/auth/callback"
// Scopes
DefaultScopes = "openid profile email offline_access"
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
RefreshScopes = "openid profile email"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
// OpenAI uses hex encoding instead of base64url
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(64)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
// Uses base64url encoding as per RFC 7636
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
params.Set("codex_cli_simplified_flow", "true")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
}
// TokenResponse represents the token response from OpenAI OAuth
type TokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// IDTokenClaims represents the claims from OpenAI ID Token
type IDTokenClaims struct {
// Standard claims
Sub string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Iss string `json:"iss"`
Aud []string `json:"aud"` // OpenAI returns aud as an array
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
// OpenAI specific claims (nested under https://api.openai.com/auth)
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
}
// OpenAIAuthClaims represents the OpenAI specific auth claims
type OpenAIAuthClaims struct {
ChatGPTAccountID string `json:"chatgpt_account_id"`
ChatGPTUserID string `json:"chatgpt_user_id"`
UserID string `json:"user_id"`
Organizations []OrganizationClaim `json:"organizations"`
}
// OrganizationClaim represents an organization in the ID Token
type OrganizationClaim struct {
ID string `json:"id"`
Role string `json:"role"`
Title string `json:"title"`
IsDefault bool `json:"is_default"`
}
// BuildTokenRequest creates a token exchange request for OpenAI
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: redirectURI,
CodeVerifier: codeVerifier,
}
}
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
Scope: RefreshScopes,
}
}
// ToFormData converts TokenRequest to URL-encoded form data
func (r *TokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("code", r.Code)
params.Set("redirect_uri", r.RedirectURI)
params.Set("code_verifier", r.CodeVerifier)
return params.Encode()
}
// ToFormData converts RefreshTokenRequest to URL-encoded form data
func (r *RefreshTokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("refresh_token", r.RefreshToken)
params.Set("scope", r.Scope)
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims
// Note: This does NOT verify the signature - it only decodes the payload
// For production, you should verify the token signature using OpenAI's public keys
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
// Decode payload (second part)
payload := parts[1]
// Add padding if necessary
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
// Try standard encoding
decoded, err = base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
}
}
var claims IDTokenClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
return &claims, nil
}
// ExtractUserInfo extracts user information from ID Token claims
type UserInfo struct {
Email string
ChatGPTAccountID string
ChatGPTUserID string
UserID string
OrganizationID string
Organizations []OrganizationClaim
}
// GetUserInfo extracts user info from ID Token claims
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
info := &UserInfo{
Email: c.Email,
}
if c.OpenAIAuth != nil {
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
info.UserID = c.OpenAIAuth.UserID
info.Organizations = c.OpenAIAuth.Organizations
// Get default organization ID
for _, org := range c.OpenAIAuth.Organizations {
if org.IsDefault {
info.OrganizationID = org.ID
break
}
}
// If no default, use first org
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
}
}
return info
}

View File

@@ -0,0 +1,18 @@
package openai
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{
"codex_vscode/",
"codex_cli_rs/",
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool {
for _, prefix := range CodexCLIUserAgentPrefixes {
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
return true
}
}
return false
}

View File

@@ -0,0 +1,42 @@
package pagination
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}

View File

@@ -9,22 +9,22 @@ import (
// Response 标准API响应格式 // Response 标准API响应格式
type Response struct { type Response struct {
Code int `json:"code"` Code int `json:"code"`
Message string `json:"message"` Message string `json:"message"`
Data interface{} `json:"data,omitempty"` Data any `json:"data,omitempty"`
} }
// PaginatedData 分页数据格式(匹配前端期望) // PaginatedData 分页数据格式(匹配前端期望)
type PaginatedData struct { type PaginatedData struct {
Items interface{} `json:"items"` Items any `json:"items"`
Total int64 `json:"total"` Total int64 `json:"total"`
Page int `json:"page"` Page int `json:"page"`
PageSize int `json:"page_size"` PageSize int `json:"page_size"`
Pages int `json:"pages"` Pages int `json:"pages"`
} }
// Success 返回成功响应 // Success 返回成功响应
func Success(c *gin.Context, data interface{}) { func Success(c *gin.Context, data any) {
c.JSON(http.StatusOK, Response{ c.JSON(http.StatusOK, Response{
Code: 0, Code: 0,
Message: "success", Message: "success",
@@ -33,7 +33,7 @@ func Success(c *gin.Context, data interface{}) {
} }
// Created 返回创建成功响应 // Created 返回创建成功响应
func Created(c *gin.Context, data interface{}) { func Created(c *gin.Context, data any) {
c.JSON(http.StatusCreated, Response{ c.JSON(http.StatusCreated, Response{
Code: 0, Code: 0,
Message: "success", Message: "success",
@@ -75,7 +75,7 @@ func InternalError(c *gin.Context, message string) {
} }
// Paginated 返回分页数据 // Paginated 返回分页数据
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) { func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
pages := int(math.Ceil(float64(total) / float64(pageSize))) pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 { if pages < 1 {
pages = 1 pages = 1
@@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in
}) })
} }
// PaginationResult 分页结果(与repository.PaginationResult兼容 // PaginationResult 分页结果(与pagination.PaginationResult兼容
type PaginationResult struct { type PaginationResult struct {
Total int64 Total int64
Page int Page int
@@ -99,7 +99,7 @@ type PaginationResult struct {
} }
// PaginatedWithResult 使用PaginationResult返回分页数据 // PaginatedWithResult 使用PaginationResult返回分页数据
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) { func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
if pagination == nil { if pagination == nil {
Success(c, PaginatedData{ Success(c, PaginatedData{
Items: items, Items: items,

View File

@@ -1,43 +1,39 @@
package sysutil package sysutil
import ( import (
"fmt"
"log" "log"
"os/exec" "os"
"runtime" "runtime"
"time"
) )
const serviceName = "sub2api" // RestartService triggers a service restart by gracefully exiting.
// RestartService triggers a service restart via systemd.
// //
// IMPORTANT: This function initiates the restart and returns immediately. // This relies on systemd's Restart=always configuration to automatically
// The actual restart happens asynchronously - the current process will be killed // restart the service after it exits. This is the industry-standard approach:
// by systemd and a new process will be started. // - Simple and reliable
// // - No sudo permissions needed
// We use Start() instead of Run() because: // - No complex process management
// - systemctl restart will kill the current process first // - Leverages systemd's native restart capability
// - Run() waits for completion, but the process dies before completion
// - Start() spawns the command independently, allowing systemd to handle the full cycle
// //
// Prerequisites: // Prerequisites:
// - Linux OS with systemd // - Linux OS with systemd
// - NOPASSWD sudo access configured (install.sh creates /etc/sudoers.d/sub2api) // - Service configured with Restart=always in systemd unit file
func RestartService() error { func RestartService() error {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
return fmt.Errorf("systemd restart only available on Linux") log.Println("Service restart via exit only works on Linux with systemd")
return nil
} }
log.Println("Initiating service restart...") log.Println("Initiating service restart by graceful exit...")
log.Println("systemd will automatically restart the service (Restart=always)")
// The sub2api user has NOPASSWD sudo access for systemctl commands // Give a moment for logs to flush and response to be sent
// (configured by install.sh in /etc/sudoers.d/sub2api). go func() {
cmd := exec.Command("sudo", "systemctl", "restart", serviceName) time.Sleep(100 * time.Millisecond)
if err := cmd.Start(); err != nil { os.Exit(0)
return fmt.Errorf("failed to initiate service restart: %w", err) }()
}
log.Println("Service restart initiated successfully")
return nil return nil
} }

View File

@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
func TestTimeNowAffected(t *testing.T) { func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first // Reset to UTC first
Init("UTC") if err := Init("UTC"); err != nil {
t.Fatalf("Init failed with UTC: %v", err)
}
utcNow := time.Now() utcNow := time.Now()
// Switch to Shanghai (UTC+8) // Switch to Shanghai (UTC+8)
Init("Asia/Shanghai") if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
shanghaiNow := time.Now() shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation // The times should be the same instant, but different timezone representation
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
} }
func TestToday(t *testing.T) { func TestToday(t *testing.T) {
Init("Asia/Shanghai") if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
today := Today() today := Today()
now := Now() now := Now()
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
} }
func TestStartOfDay(t *testing.T) { func TestStartOfDay(t *testing.T) {
Init("Asia/Shanghai") if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
// Create a time at 15:30:45 // Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location()) testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic // This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code // and why StartOfDay is more reliable for timezone-aware code
Init("Asia/Shanghai") if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
now := Now() now := Now()

View File

@@ -0,0 +1,8 @@
package usagestats
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}

View File

@@ -0,0 +1,209 @@
package usagestats
import "time"
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
StartTime *time.Time
EndTime *time.Time
}
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"`
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
HighestCostDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
} `json:"highest_request_day"`
}
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
}

View File

@@ -2,10 +2,14 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type AccountRepository struct { type AccountRepository struct {
@@ -22,14 +26,34 @@ func (r *AccountRepository) Create(ctx context.Context, account *model.Account)
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 填充 GroupIDs 虚拟字段 // 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups)) account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups { for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID) account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
return &account, nil
}
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
if crsAccountID == "" {
return nil, nil
}
var account model.Account
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
} }
return &account, nil return &account, nil
} }
@@ -47,12 +71,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
} }
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) { func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "")
} }
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query // ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) { func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account var accounts []model.Account
var total int64 var total int64
@@ -77,15 +101,19 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
return nil, nil, err return nil, nil, err
} }
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil { if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
// 填充每个 Account 的 GroupIDs 虚拟字段 // 填充每个 Account 的虚拟字段(GroupIDs 和 Groups
for i := range accounts { for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups)) accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups { for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID) accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
if ag.Group != nil {
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
}
} }
} }
@@ -94,7 +122,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
pages++ pages++
} }
return accounts, &PaginationResult{ return accounts, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -130,7 +158,7 @@ func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"status": model.StatusError, "status": model.StatusError,
"error_message": errorMsg, "error_message": errorMsg,
}).Error }).Error
@@ -221,12 +249,44 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
return accounts, err return accounts, err
} }
// ListSchedulableByPlatform 按平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
}
// SetRateLimited 标记账号为限流状态(429) // SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"rate_limited_at": now, "rate_limited_at": now,
"rate_limit_reset_at": resetAt, "rate_limit_reset_at": resetAt,
}).Error }).Error
} }
@@ -240,7 +300,7 @@ func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until t
// ClearRateLimit 清除账号的限流状态 // ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error { func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"rate_limited_at": nil, "rate_limited_at": nil,
"rate_limit_reset_at": nil, "rate_limit_reset_at": nil,
"overload_until": nil, "overload_until": nil,
@@ -249,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
// UpdateSessionWindow 更新账号的5小时时间窗口信息 // UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]interface{}{ updates := map[string]any{
"session_window_status": status, "session_window_status": status,
} }
if start != nil { if start != nil {
@@ -266,3 +326,75 @@ func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error Update("schedulable", schedulable).Error
} }
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
}
// Get current account to preserve existing Extra data
var account model.Account
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err
}
// Initialize Extra if nil
if account.Extra == nil {
account.Extra = make(model.JSONB)
}
// Merge updates into existing Extra
for k, v := range updates {
account.Extra[k] = v
}
// Save updated Extra
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("extra", account.Extra).Error
}
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
updateMap := map[string]any{}
if updates.Name != nil {
updateMap["name"] = *updates.Name
}
if updates.ProxyID != nil {
updateMap["proxy_id"] = updates.ProxyID
}
if updates.Concurrency != nil {
updateMap["concurrency"] = *updates.Concurrency
}
if updates.Priority != nil {
updateMap["priority"] = *updates.Priority
}
if updates.Status != nil {
updateMap["status"] = *updates.Status
}
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
}
if len(updateMap) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).
Model(&model.Account{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
return result.RowsAffected, result.Error
}

View File

@@ -0,0 +1,51 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
)
type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
return &apiKeyCache{rdb: rdb}
}
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return c.rdb.Incr(ctx, apiKey).Err()
}
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}

View File

@@ -2,7 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
} }
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) { func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
@@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
pages++ pages++
} }
return keys, &PaginationResult{ return keys, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err return count > 0, err
} }
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) { func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
@@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
pages++ pages++
} }
return keys, &PaginationResult{ return keys, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),

View File

@@ -0,0 +1,174 @@
package repository
import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
rdb *redis.Client
}
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
return &billingCache{rdb: rdb}
}
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return 0, err
}
return strconv.ParseFloat(val, 64)
}
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
return c.parseSubscriptionCache(result)
}
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
result := &ports.SubscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
if data == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]any{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
}
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,228 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
type claudeOAuthService struct{}
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{}
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
reqBody := map[string]any{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID,
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code
codeState := ""
if idx := strings.Index(code, "#"); idx != -1 {
authCode = code[:idx]
codeState = code[idx+1:]
}
reqBody := map[string]any{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
if codeState != "" {
reqBody["state"] = codeState
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -0,0 +1,63 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type claudeUsageService struct{}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{}
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("failed to get default transport")
}
transport = transport.Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
var usageResp service.ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
return &usageResp, nil
}

View File

@@ -0,0 +1,204 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
// Key prefixes for independent slot keys
// Format: concurrency:account:{accountID}:{requestID}
accountSlotKeyPrefix = "concurrency:account:"
// Format: concurrency:user:{userID}:{requestID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Slot TTL - each slot expires independently
slotTTL = 5 * time.Minute
)
var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds
acquireScript = redis.NewScript(`
local pattern = KEYS[1]
local slotKey = KEYS[2]
local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
-- Count existing slots using SCAN
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- Check if we can acquire a slot
if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl)
return 1
end
return 0
`)
// getCountScript counts slots using SCAN
// KEYS[1] = pattern for SCAN
getCountScript = redis.NewScript(`
local pattern = KEYS[1]
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
return count
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
type concurrencyCache struct {
rdb *redis.Client
}
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
}
// Helper functions for key generation
func accountSlotKey(accountID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
}
func accountSlotPattern(accountID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID)
slotKey := accountSlotKey(accountID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
slotKey := accountSlotKey(accountID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
pattern := accountSlotPattern(accountID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
}
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID)
slotKey := userSlotKey(userID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
slotKey := userSlotKey(userID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
pattern := userSlotPattern(userID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
}
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -0,0 +1,48 @@
package repository
import (
"context"
"encoding/json"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
type emailCache struct {
rdb *redis.Client
}
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
return &emailCache{rdb: rdb}
}
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data ports.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKeyPrefix + email
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,35 @@
package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const stickySessionPrefix = "sticky_session:"
type gatewayCache struct {
rdb *redis.Client
}
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
return &gatewayCache{rdb: rdb}
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
key := stickySessionPrefix + sessionHash
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
return c.rdb.Expire(ctx, key, ttl).Err()
}

View File

@@ -0,0 +1,116 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type githubReleaseClient struct {
httpClient *http.Client
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release service.GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
return &release, nil
}
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer func() { _ = out.Close() }()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxSize)
if written > maxSize {
_ = os.Remove(dest) // Clean up partial file (best-effort)
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
}
return nil
}
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}

View File

@@ -2,7 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
} }
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) { func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", nil)
} }
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) { func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group var groups []model.Group
var total int64 var total int64
@@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination
pages++ pages++
} }
return groups, &PaginationResult{ return groups, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),

View File

@@ -0,0 +1,67 @@
package repository
import (
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
// httpUpstreamService is a generic HTTP upstream service that can be used for
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
type httpUpstreamService struct {
defaultClient *http.Client
cfg *config.Config
}
// NewHTTPUpstream creates a new generic HTTP upstream service
func NewHTTPUpstream(cfg *config.Config) ports.HTTPUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &httpUpstreamService{
defaultClient: &http.Client{Transport: transport},
cfg: cfg,
}
}
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
if proxyURL == "" {
return s.defaultClient.Do(req)
}
client := s.createProxyClient(proxyURL)
return client.Do(req)
}
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return s.defaultClient
}
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &http.Client{Transport: transport}
}

View File

@@ -0,0 +1,47 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
)
type identityCache struct {
rdb *redis.Client
}
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
return &identityCache{rdb: rdb}
}
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var fp ports.Fingerprint
if err := json.Unmarshal([]byte(val), &fp); err != nil {
return nil, err
}
return &fp, nil
}
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := json.Marshal(fp)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
}

View File

@@ -0,0 +1,92 @@
package repository
import (
"context"
"fmt"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/imroc/req/v3"
)
type openaiOAuthService struct{}
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
return &openaiOAuthService{}
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", openai.ClientID)
formData.Set("code", code)
formData.Set("redirect_uri", redirectURI)
formData.Set("code_verifier", codeVerifier)
var tokenResp openai.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID)
formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
client := req.C().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -0,0 +1,73 @@
package repository
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient() service.PricingRemoteClient {
return &pricingRemoteClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
}

View File

@@ -0,0 +1,104 @@
package repository
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"golang.org/x/net/proxy"
)
type proxyProbeService struct{}
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{}
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
transport, err := createProxyTransport(proxyURL)
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
}
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
}
return &service.ProxyExitInfo{
IP: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
}, latencyMs, nil
}
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -2,7 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
} }
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) { func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query // ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) { func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy var proxies []model.Proxy
var total int64 var total int64
@@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination
pages++ pages++
} }
return proxies, &PaginationResult{ return proxies, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),

View File

@@ -0,0 +1,49 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemRateLimitDuration = 24 * time.Hour
)
type redeemCache struct {
rdb *redis.Client
}
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
return &redeemCache{rdb: rdb}
}
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
key := redeemLockKeyPrefix + code
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
key := redeemLockKeyPrefix + code
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -2,7 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
@@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
} }
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) { func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query // ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) { func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode var codes []model.RedeemCode
var total int64 var total int64
@@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin
pages++ pages++
} }
return codes, &PaginationResult{ return codes, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -98,7 +99,7 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
now := time.Now() now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused). Where("id = ? AND status = ?", id, model.StatusUnused).
Updates(map[string]interface{}{ Updates(map[string]any{
"status": model.StatusUsed, "status": model.StatusUsed,
"used_by": userID, "used_by": userID,
"used_at": now, "used_at": now,

View File

@@ -1,9 +1,5 @@
package repository package repository
import (
"gorm.io/gorm"
)
// Repositories 所有仓库的集合 // Repositories 所有仓库的集合
type Repositories struct { type Repositories struct {
User *UserRepository User *UserRepository
@@ -16,59 +12,3 @@ type Repositories struct {
Setting *SettingRepository Setting *SettingRepository
UserSubscription *UserSubscriptionRepository UserSubscription *UserSubscriptionRepository
} }
// NewRepositories 创建所有仓库
func NewRepositories(db *gorm.DB) *Repositories {
return &Repositories{
User: NewUserRepository(db),
ApiKey: NewApiKeyRepository(db),
Group: NewGroupRepository(db),
Account: NewAccountRepository(db),
Proxy: NewProxyRepository(db),
RedeemCode: NewRedeemCodeRepository(db),
UsageLog: NewUsageLogRepository(db),
Setting: NewSettingRepository(db),
UserSubscription: NewUserSubscriptionRepository(db),
}
}
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}

View File

@@ -2,7 +2,7 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"

View File

@@ -0,0 +1,55 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
type turnstileVerifier struct {
httpClient *http.Client
}
func NewTurnstileVerifier() service.TurnstileVerifier {
return &turnstileVerifier{
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := v.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
var result service.TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decode response: %w", err)
}
return &result, nil
}

View File

@@ -0,0 +1,28 @@
package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const updateCacheKey = "update:latest"
type updateCache struct {
rdb *redis.Client
}
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
return &updateCache{rdb: rdb}
}
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
return c.rdb.Get(ctx, updateCacheKey).Result()
}
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
}

View File

@@ -2,8 +2,10 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
@@ -17,6 +19,30 @@ func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository {
return &UsageLogRepository{db: db} return &UsageLogRepository{db: db}
} }
// getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
var perfStats struct {
RequestCount int64 `gorm:"column:request_count"`
TokenCount int64 `gorm:"column:token_count"`
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
`).
Where("created_at >= ?", fiveMinutesAgo)
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
db.Scan(&perfStats)
// 返回5分钟平均值
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
}
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error return r.db.WithContext(ctx).Create(log).Error
} }
@@ -30,7 +56,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag
return &log, nil return &log, nil
} }
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -49,7 +75,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
pages++ pages++
} }
return logs, &PaginationResult{ return logs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -57,7 +83,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
}, nil }, nil
} }
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -76,7 +102,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
pages++ pages++
} }
return logs, &PaginationResult{ return logs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -111,46 +137,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
} }
// DashboardStats 仪表盘统计 // DashboardStats 仪表盘统计
type DashboardStats struct { type DashboardStats = usagestats.DashboardStats
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
}
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats var stats DashboardStats
@@ -267,10 +254,13 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
stats.TodayCost = todayStats.TodayCost stats.TodayCost = todayStats.TodayCost
stats.TodayActualCost = todayStats.TodayActualCost stats.TodayActualCost = todayStats.TodayActualCost
// 性能指标RPM 和 TPM最近1分钟全局
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, 0)
return &stats, nil return &stats, nil
} }
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -289,7 +279,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
pages++ pages++
} }
return logs, &PaginationResult{ return logs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -297,7 +287,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}, nil }, nil
} }
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
@@ -306,7 +296,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
@@ -315,7 +305,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
@@ -324,7 +314,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
@@ -337,15 +327,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
} }
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}
// GetAccountTodayStats 获取账号今日统计 // GetAccountTodayStats 获取账号今日统计
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*AccountStats, error) { func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today() today := timezone.Today()
var stats struct { var stats struct {
@@ -367,7 +350,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
return nil, err return nil, err
} }
return &AccountStats{ return &usagestats.AccountStats{
Requests: stats.Requests, Requests: stats.Requests,
Tokens: stats.Tokens, Tokens: stats.Tokens,
Cost: stats.Cost, Cost: stats.Cost,
@@ -375,7 +358,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
} }
// GetAccountWindowStats 获取账号时间窗口内的统计 // GetAccountWindowStats 获取账号时间窗口内的统计
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*AccountStats, error) { func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
var stats struct { var stats struct {
Requests int64 `gorm:"column:requests"` Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"` Tokens int64 `gorm:"column:tokens"`
@@ -395,7 +378,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
return nil, err return nil, err
} }
return &AccountStats{ return &usagestats.AccountStats{
Requests: stats.Requests, Requests: stats.Requests,
Tokens: stats.Tokens, Tokens: stats.Tokens,
Cost: stats.Cost, Cost: stats.Cost,
@@ -403,109 +386,16 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
} }
// TrendDataPoint represents a single point in trend data // TrendDataPoint represents a single point in trend data
type TrendDataPoint struct { type TrendDataPoint = usagestats.TrendDataPoint
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model // ModelStat represents usage statistics for a single model
type ModelStat struct { type ModelStat = usagestats.ModelStat
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point // UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct { type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GetUsageTrend returns usage trend data grouped by date
// granularity: "day" or "hour"
func (r *UsageLogRepository) GetUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
var results []TrendDataPoint
// Choose date format based on granularity
var dateFormat string
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
} else {
dateFormat = "YYYY-MM-DD"
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`, dateFormat).
Where("created_at >= ? AND created_at < ?", startTime, endTime).
Group("date").
Order("date ASC").
Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetModelStats returns usage statistics grouped by model
func (r *UsageLogRepository) GetModelStats(ctx context.Context, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`).
Where("created_at >= ? AND created_at < ?", startTime, endTime).
Group("model").
Order("total_tokens DESC").
Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// ApiKeyUsageTrendPoint represents API key usage trend data point // ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct { type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date // GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
@@ -598,34 +488,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
} }
// UserDashboardStats 用户仪表盘统计 // UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct { type UserDashboardStats = usagestats.UserDashboardStats
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
}
// GetUserDashboardStats 获取用户专属的仪表盘统计 // GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
@@ -708,6 +571,9 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
stats.TodayCost = todayStats.TodayCost stats.TodayCost = todayStats.TodayCost
stats.TodayActualCost = todayStats.TodayActualCost stats.TodayActualCost = todayStats.TodayActualCost
// 性能指标RPM 和 TPM最近1分钟仅统计该用户的请求
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, userID)
return &stats, nil return &stats, nil
} }
@@ -772,15 +638,10 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
} }
// UsageLogFilters represents filters for usage log queries // UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct { type UsageLogFilters = usagestats.UsageLogFilters
UserID int64
ApiKeyID int64
StartTime *time.Time
EndTime *time.Time
}
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *PaginationResult, error) { func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -816,7 +677,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
pages++ pages++
} }
return logs, &PaginationResult{ return logs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -825,23 +686,10 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
} }
// UsageStats represents usage statistics // UsageStats represents usage statistics
type UsageStats struct { type UsageStats = usagestats.UsageStats
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user // BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct { type BatchUserUsageStats = usagestats.BatchUserUsageStats
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// GetBatchUserUsageStats gets today and total actual_cost for multiple users // GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
@@ -901,11 +749,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
} }
// BatchApiKeyUsageStats represents usage stats for a single API key // BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct { type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys // GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
@@ -964,6 +808,79 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
return result, nil return result, nil
} }
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
var results []TrendDataPoint
var dateFormat string
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
} else {
dateFormat = "YYYY-MM-DD"
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`, dateFormat).
Where("created_at >= ? AND created_at < ?", startTime, endTime)
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
if apiKeyID > 0 {
db = db.Where("api_key_id = ?", apiKeyID)
}
err := db.Group("date").Order("date ASC").Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`).
Where("created_at >= ? AND created_at < ?", startTime, endTime)
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
if apiKeyID > 0 {
db = db.Where("api_key_id = ?", apiKeyID)
}
if accountID > 0 {
db = db.Where("account_id = ?", accountID)
}
err := db.Group("model").Order("total_tokens DESC").Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetGlobalStats gets usage statistics for all users within a time range // GetGlobalStats gets usage statistics for all users within a time range
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
var stats struct { var stats struct {
@@ -1004,3 +921,169 @@ func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
AverageDurationMs: stats.AverageDurationMs, AverageDurationMs: stats.AverageDurationMs,
}, nil }, nil
} }
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory = usagestats.AccountUsageHistory
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary = usagestats.AccountUsageSummary
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 {
daysCount = 30
}
// Get daily history
var historyResults []struct {
Date string `gorm:"column:date"`
Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"`
Cost float64 `gorm:"column:cost"`
ActualCost float64 `gorm:"column:actual_cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Group("date").
Order("date ASC").
Scan(&historyResults).Error
if err != nil {
return nil, err
}
// Build history with labels
history := make([]AccountUsageHistory, 0, len(historyResults))
for _, h := range historyResults {
// Parse date to get label (MM/DD)
t, _ := time.Parse("2006-01-02", h.Date)
label := t.Format("01/02")
history = append(history, AccountUsageHistory{
Date: h.Date,
Label: label,
Requests: h.Requests,
Tokens: h.Tokens,
Cost: h.Cost,
ActualCost: h.ActualCost,
})
}
// Calculate summary
var totalActualCost, totalStandardCost float64
var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history {
h := &history[i]
totalActualCost += h.ActualCost
totalStandardCost += h.Cost
totalRequests += h.Requests
totalTokens += h.Tokens
if highestCostDay == nil || h.ActualCost > highestCostDay.ActualCost {
highestCostDay = h
}
if highestRequestDay == nil || h.Requests > highestRequestDay.Requests {
highestRequestDay = h
}
}
actualDaysUsed := len(history)
if actualDaysUsed == 0 {
actualDaysUsed = 1
}
// Get average duration
var avgDuration struct {
AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Scan(&avgDuration)
summary := AccountUsageSummary{
Days: daysCount,
ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost,
TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration.AvgDurationMs,
}
// Set today's stats
todayStr := timezone.Now().Format("2006-01-02")
for i := range history {
if history[i].Date == todayStr {
summary.Today = &struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}{
Date: history[i].Date,
Cost: history[i].ActualCost,
Requests: history[i].Requests,
Tokens: history[i].Tokens,
}
break
}
}
// Set highest cost day
if highestCostDay != nil {
summary.HighestCostDay = &struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
}{
Date: highestCostDay.Date,
Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost,
Requests: highestCostDay.Requests,
}
}
// Set highest request day
if highestRequestDay != nil {
summary.HighestRequestDay = &struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
}{
Date: highestRequestDay.Date,
Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost,
}
}
// Get model statistics using the unified method
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
if err != nil {
models = []ModelStat{}
}
return &AccountUsageStatsResponse{
History: history,
Summary: summary,
Models: models,
}, nil
}

View File

@@ -2,7 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
} }
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) { func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists users with optional filtering by status, role, and search query // ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) { func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User var users []model.User
var total int64 var total int64
@@ -65,23 +66,53 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP
} }
if search != "" { if search != "" {
searchPattern := "%" + search + "%" searchPattern := "%" + search + "%"
db = db.Where("email ILIKE ?", searchPattern) db = db.Where(
"email ILIKE ? OR username ILIKE ? OR wechat ILIKE ?",
searchPattern, searchPattern, searchPattern,
)
} }
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
// Query users with pagination (reuse the same db with filters applied)
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil { if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
// Batch load subscriptions for all users (avoid N+1)
if len(users) > 0 {
userIDs := make([]int64, len(users))
userMap := make(map[int64]*model.User, len(users))
for i := range users {
userIDs[i] = users[i].ID
userMap[users[i].ID] = &users[i]
}
// Query active subscriptions with groups in one query
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive).
Find(&subscriptions).Error; err != nil {
return nil, nil, err
}
// Associate subscriptions with users
for i := range subscriptions {
if user, ok := userMap[subscriptions[i].UserID]; ok {
user.Subscriptions = append(user.Subscriptions, subscriptions[i])
}
}
}
pages := int(total) / params.Limit() pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 { if int(total)%params.Limit() > 0 {
pages++ pages++
} }
return users, &PaginationResult{ return users, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -128,3 +159,15 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Order("id ASC").
First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}

View File

@@ -4,7 +4,8 @@ import (
"context" "context"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
} }
// ListByGroupID 获取分组的所有订阅(分页) // ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) { func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
@@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
pages++ pages++
} }
return subs, &PaginationResult{ return subs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
} }
// List 获取所有订阅(分页,支持筛选) // List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) { func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
@@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
pages++ pages++
} }
return subs, &PaginationResult{ return subs, &pagination.PaginationResult{
Total: total, Total: total,
Page: params.Page, Page: params.Page,
PageSize: params.Limit(), PageSize: params.Limit(),
@@ -184,7 +185,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD), "daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD), "weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD), "monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
@@ -196,7 +197,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"daily_usage_usd": 0, "daily_usage_usd": 0,
"daily_window_start": newWindowStart, "daily_window_start": newWindowStart,
"updated_at": time.Now(), "updated_at": time.Now(),
@@ -207,7 +208,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"weekly_usage_usd": 0, "weekly_usage_usd": 0,
"weekly_window_start": newWindowStart, "weekly_window_start": newWindowStart,
"updated_at": time.Now(), "updated_at": time.Now(),
@@ -218,7 +219,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"monthly_usage_usd": 0, "monthly_usage_usd": 0,
"monthly_window_start": newWindowStart, "monthly_window_start": newWindowStart,
"updated_at": time.Now(), "updated_at": time.Now(),
@@ -229,7 +230,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"daily_window_start": activateTime, "daily_window_start": activateTime,
"weekly_window_start": activateTime, "weekly_window_start": activateTime,
"monthly_window_start": activateTime, "monthly_window_start": activateTime,
@@ -241,7 +242,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"status": status, "status": status,
"updated_at": time.Now(), "updated_at": time.Now(),
}).Error }).Error
@@ -251,7 +252,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"expires_at": newExpiresAt, "expires_at": newExpiresAt,
"updated_at": time.Now(), "updated_at": time.Now(),
}).Error }).Error
@@ -261,7 +262,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]interface{}{ Updates(map[string]any{
"notes": notes, "notes": notes,
"updated_at": time.Now(), "updated_at": time.Now(),
}).Error }).Error
@@ -280,7 +281,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Updates(map[string]interface{}{ Updates(map[string]any{
"status": model.SubscriptionStatusExpired, "status": model.SubscriptionStatusExpired,
"updated_at": time.Now(), "updated_at": time.Now(),
}) })

View File

@@ -0,0 +1,52 @@
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/google/wire"
)
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewApiKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewUsageLogRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"),
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewConcurrencyCache,
NewEmailCache,
NewIdentityCache,
NewRedeemCache,
NewUpdateCache,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
NewPricingRemoteClient,
NewGitHubReleaseClient,
NewProxyExitInfoProber,
NewClaudeUsageFetcher,
NewClaudeOAuthClient,
NewHTTPUpstream,
NewOpenAIOAuthClient,
// Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
)

View File

@@ -0,0 +1,45 @@
package server
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// ProviderSet 提供服务器层的依赖
var ProviderSet = wire.NewSet(
ProvideRouter,
ProvideHTTPServer,
)
// ProvideRouter 提供路由器
func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
r := gin.New()
r.Use(gin.Recovery())
return SetupRouter(r, cfg, handlers, services, repos)
}
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
return &http.Server{
Addr: cfg.Server.Address(),
Handler: router,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
}

View File

@@ -0,0 +1,312 @@
package server
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/web"
"net/http"
"github.com/gin-gonic/gin"
)
// SetupRouter 配置路由器中间件和路由
func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
// 应用中间件
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
}
return r
}
// registerRoutes 注册所有 HTTP 路由
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Claude Code 遥测日志忽略直接返回200
r.POST("/api/event_logging/batch", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.AdminAuth(s.Auth, repos.User, s.Setting))
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// OpenAI OAuth routes
openai := admin.Group("/openai")
{
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses)
}

View File

@@ -4,8 +4,9 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -16,37 +17,37 @@ var (
// CreateAccountRequest 创建账号请求 // CreateAccountRequest 创建账号请求
type CreateAccountRequest struct { type CreateAccountRequest struct {
Name string `json:"name"` Name string `json:"name"`
Platform string `json:"platform"` Platform string `json:"platform"`
Type string `json:"type"` Type string `json:"type"`
Credentials map[string]interface{} `json:"credentials"` Credentials map[string]any `json:"credentials"`
Extra map[string]interface{} `json:"extra"` Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
Priority int `json:"priority"` Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
} }
// UpdateAccountRequest 更新账号请求 // UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
Credentials *map[string]interface{} `json:"credentials"` Credentials *map[string]any `json:"credentials"`
Extra *map[string]interface{} `json:"extra"` Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"` Priority *int `json:"priority"`
Status *string `json:"status"` Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
} }
// AccountService 账号管理服务 // AccountService 账号管理服务
type AccountService struct { type AccountService struct {
accountRepo *repository.AccountRepository accountRepo ports.AccountRepository
groupRepo *repository.GroupRepository groupRepo ports.GroupRepository
} }
// NewAccountService 创建账号服务实例 // NewAccountService 创建账号服务实例
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService { func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
return &AccountService{ return &AccountService{
accountRepo: accountRepo, accountRepo: accountRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
} }
// List 获取账号列表 // List 获取账号列表
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) { func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params) accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err) return nil, nil, fmt.Errorf("list accounts: %w", err)

View File

@@ -10,20 +10,23 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
const ( const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages" testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testModel = "claude-sonnet-4-5-20250929" testOpenAIAPIURL = "https://api.openai.com/v1/responses"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
) )
// TestEvent represents a SSE event for account testing // TestEvent represents a SSE event for account testing
@@ -37,39 +40,46 @@ type TestEvent struct {
// AccountTestService handles account testing operations // AccountTestService handles account testing operations
type AccountTestService struct { type AccountTestService struct {
repos *repository.Repositories accountRepo ports.AccountRepository
oauthService *OAuthService oauthService *OAuthService
httpClient *http.Client openaiOAuthService *OpenAIOAuthService
httpUpstream ports.HTTPUpstream
} }
// NewAccountTestService creates a new AccountTestService // NewAccountTestService creates a new AccountTestService
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService { func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream ports.HTTPUpstream) *AccountTestService {
return &AccountTestService{ return &AccountTestService{
repos: repos, accountRepo: accountRepo,
oauthService: oauthService, oauthService: oauthService,
httpClient: &http.Client{ openaiOAuthService: openaiOAuthService,
Timeout: 60 * time.Second, httpUpstream: httpUpstream,
},
} }
} }
// generateSessionString generates a Claude Code style session string // generateSessionString generates a Claude Code style session string
func generateSessionString() string { func generateSessionString() (string, error) {
bytes := make([]byte, 32) bytes := make([]byte, 32)
rand.Read(bytes) if _, err := rand.Read(bytes); err != nil {
return "", err
}
hex64 := hex.EncodeToString(bytes) hex64 := hex.EncodeToString(bytes)
sessionUUID := uuid.New().String() sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID) return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
} }
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts // createTestPayload creates a Claude Code style test request payload
func createTestPayload() map[string]interface{} { func createTestPayload(modelID string) (map[string]any, error) {
return map[string]interface{}{ sessionID, err := generateSessionString()
"model": testModel, if err != nil {
"messages": []map[string]interface{}{ return nil, err
}
return map[string]any{
"model": modelID,
"messages": []map[string]any{
{ {
"role": "user", "role": "user",
"content": []map[string]interface{}{ "content": []map[string]any{
{ {
"type": "text", "type": "text",
"text": "hi", "text": "hi",
@@ -80,7 +90,7 @@ func createTestPayload() map[string]interface{} {
}, },
}, },
}, },
"system": []map[string]interface{}{ "system": []map[string]any{
{ {
"type": "text", "type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.", "text": "You are Claude Code, Anthropic's official CLI for Claude.",
@@ -90,47 +100,62 @@ func createTestPayload() map[string]interface{} {
}, },
}, },
"metadata": map[string]string{ "metadata": map[string]string{
"user_id": generateSessionString(), "user_id": sessionID,
}, },
"max_tokens": 1024, "max_tokens": 1024,
"temperature": 1, "temperature": 1,
"stream": true, "stream": true,
} }, nil
}
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
func createApiKeyTestPayload(model string) map[string]interface{} {
return map[string]interface{}{
"model": model,
"messages": []map[string]interface{}{
{
"role": "user",
"content": "hi",
},
},
"max_tokens": 1024,
"stream": true,
}
} }
// TestAccountConnection tests an account's connection by sending a test request // TestAccountConnection tests an account's connection by sending a test request
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error { // All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Get account // Get account
account, err := s.repos.Account.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, "Account not found") return s.sendErrorAndEnd(c, "Account not found")
} }
// Determine authentication method based on account type // Route to platform-specific test method
if account.IsOpenAI() {
return s.testOpenAIAccountConnection(c, account, modelID)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
// testClaudeAccountConnection tests an Anthropic Claude account's connection
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
ctx := c.Request.Context()
// Determine the model to use
testModelID := modelID
if testModelID == "" {
testModelID = claude.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
}
// Determine authentication method and API URL
var authToken string var authToken string
var authType string // "bearer" for OAuth, "apikey" for API Key var useBearer bool
var apiURL string var apiURL string
if account.IsOAuth() { if account.IsOAuth() {
// OAuth or Setup Token account // OAuth or Setup Token - use Bearer token
authType = "bearer" useBearer = true
apiURL = testClaudeAPIURL apiURL = testClaudeAPIURL
authToken = account.GetCredential("access_token") authToken = account.GetCredential("access_token")
if authToken == "" { if authToken == "" {
@@ -141,7 +166,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
needRefresh := false needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" { if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer if err == nil && time.Now().Unix()+300 > expiresAt {
needRefresh = true needRefresh = true
} }
} }
@@ -154,19 +179,17 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
authToken = tokenInfo.AccessToken authToken = tokenInfo.AccessToken
} }
} else if account.Type == "apikey" { } else if account.Type == "apikey" {
// API Key account // API Key - use x-api-key header
authType = "apikey" useBearer = false
authToken = account.GetCredential("api_key") authToken = account.GetCredential("api_key")
if authToken == "" { if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available") return s.sendErrorAndEnd(c, "No API key available")
} }
// Get base URL (use default if not set)
apiURL = account.GetBaseURL() apiURL = account.GetBaseURL()
if apiURL == "" { if apiURL == "" {
apiURL = "https://api.anthropic.com" apiURL = "https://api.anthropic.com"
} }
// Append /v1/messages endpoint
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages" apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
} else { } else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
@@ -179,61 +202,49 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush() c.Writer.Flush()
// Create test request payload // Create Claude Code style payload (same for all account types)
var payload map[string]interface{} payload, err := createTestPayload(testModelID)
var actualModel string if err != nil {
if authType == "apikey" { return s.sendErrorAndEnd(c, "Failed to create test payload")
// Use simpler payload for API Key (without Claude Code specific fields)
// Apply model mapping if configured
actualModel = account.GetMappedModel(testModel)
payload = createApiKeyTestPayload(actualModel)
} else {
actualModel = testModel
payload = createTestPayload()
} }
payloadBytes, _ := json.Marshal(payload) payloadBytes, _ := json.Marshal(payload)
// Send test_start event with model info // Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel}) s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request") return s.sendErrorAndEnd(c, "Failed to create request")
} }
// Set headers based on auth type // Set common headers
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2023-06-01") req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
if authType == "bearer" { // Apply Claude Code client headers
for key, value := range claude.DefaultHeaders {
req.Header.Set(key, value)
}
// Set authentication header
if useBearer {
req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
} else { } else {
// API Key uses x-api-key header
req.Header.Set("x-api-key", authToken) req.Header.Set("x-api-key", authToken)
} }
// Configure proxy if account has one // Get proxy URL
transport := http.DefaultTransport.(*http.Transport).Clone() proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL() proxyURL = account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
} }
client := &http.Client{ resp, err := s.httpUpstream.Do(req, proxyURL)
Transport: transport,
Timeout: 60 * time.Second,
}
resp, err := client.Do(req)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
@@ -241,18 +252,161 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
} }
// Process SSE stream // Process SSE stream
return s.processStream(c, resp.Body) return s.processClaudeStream(c, resp.Body)
} }
// processStream processes the SSE stream from Claude API // testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error { func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
ctx := c.Request.Context()
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
if testModelID == "" {
testModelID = openai.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
}
// Determine authentication method and API URL
var authToken string
var apiURL string
var isOAuth bool
var chatgptAccountID string
if account.IsOAuth() {
isOAuth = true
// OAuth - use Bearer token with ChatGPT internal API
authToken = account.GetOpenAIAccessToken()
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
// Check if token is expired and refresh if needed
if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
}
authToken = tokenInfo.AccessToken
}
// OAuth uses ChatGPT internal API
apiURL = chatgptCodexAPIURL
chatgptAccountID = account.GetChatGPTAccountID()
} else if account.Type == "apikey" {
// API Key - use Platform API
authToken = account.GetOpenAIApiKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
apiURL = strings.TrimSuffix(baseURL, "/") + "/v1/responses"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create OpenAI Responses API payload
payload := createOpenAITestPayload(testModelID, isOAuth)
payloadBytes, _ := json.Marshal(payload)
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
// Set common headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+authToken)
// Set OAuth-specific headers for ChatGPT internal API
if isOAuth {
req.Host = "chatgpt.com"
req.Header.Set("accept", "text/event-stream")
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
}
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Process SSE stream
return s.processOpenAIStream(c, resp.Body)
}
// createOpenAITestPayload creates a test payload for OpenAI Responses API
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
payload := map[string]any{
"model": modelID,
"input": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{
"type": "input_text",
"text": "hi",
},
},
},
},
"stream": true,
}
// OAuth accounts using ChatGPT internal API require store: false
if isOAuth {
payload["store"] = false
}
// All accounts require instructions for Responses API
payload["instructions"] = openai.DefaultInstructions
return payload
}
// processClaudeStream processes the SSE stream from Claude API
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body) reader := bufio.NewReader(body)
for { for {
line, err := reader.ReadString('\n') line, err := reader.ReadString('\n')
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
// Stream ended, send complete event
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil return nil
} }
@@ -270,7 +424,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
return nil return nil
} }
var data map[string]interface{} var data map[string]any
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
continue continue
} }
@@ -279,7 +433,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
switch eventType { switch eventType {
case "content_block_delta": case "content_block_delta":
if delta, ok := data["delta"].(map[string]interface{}); ok { if delta, ok := data["delta"].(map[string]any); ok {
if text, ok := delta["text"].(string); ok { if text, ok := delta["text"].(string); ok {
s.sendEvent(c, TestEvent{Type: "content", Text: text}) s.sendEvent(c, TestEvent{Type: "content", Text: text})
} }
@@ -289,7 +443,60 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
return nil return nil
case "error": case "error":
errorMsg := "Unknown error" errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]interface{}); ok { if errData, ok := data["error"].(map[string]any); ok {
if msg, ok := errData["message"].(string); ok {
errorMsg = msg
}
}
return s.sendErrorAndEnd(c, errorMsg)
}
}
}
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
jsonStr := strings.TrimPrefix(line, "data: ")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
var data map[string]any
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
continue
}
eventType, _ := data["type"].(string)
switch eventType {
case "response.output_text.delta":
// OpenAI Responses API uses "delta" field for text content
if delta, ok := data["delta"].(string); ok && delta != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
}
case "response.completed":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]any); ok {
if msg, ok := errData["message"].(string); ok { if msg, ok := errData["message"].(string); ok {
errorMsg = msg errorMsg = msg
} }
@@ -302,7 +509,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
// sendEvent sends a SSE event to the client // sendEvent sends a SSE event to the client
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event) eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON) if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
log.Printf("failed to write SSE event: %v", err)
return
}
c.Writer.Flush() c.Writer.Flush()
} }
@@ -310,5 +520,5 @@ func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error { func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
log.Printf("Account test error: %s", errorMsg) log.Printf("Account test error: %s", errorMsg)
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg}) s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
return fmt.Errorf(errorMsg) return fmt.Errorf("%s", errorMsg)
} }

View File

@@ -2,17 +2,14 @@ package service
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io"
"log" "log"
"net/http"
"net/url"
"sync" "sync"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// usageCache 用于缓存usage数据 // usageCache 用于缓存usage数据
@@ -35,10 +32,10 @@ type WindowStats struct {
// UsageProgress 使用量进度 // UsageProgress 使用量进度
type UsageProgress struct { type UsageProgress struct {
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+100表示100%) Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+100表示100%)
ResetsAt *time.Time `json:"resets_at"` // 重置时间 ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数 RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
} }
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
@@ -65,21 +62,24 @@ type ClaudeUsageResponse struct {
} `json:"seven_day_sonnet"` } `json:"seven_day_sonnet"`
} }
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type ClaudeUsageFetcher interface {
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
}
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
type AccountUsageService struct { type AccountUsageService struct {
repos *repository.Repositories accountRepo ports.AccountRepository
oauthService *OAuthService usageLogRepo ports.UsageLogRepository
httpClient *http.Client usageFetcher ClaudeUsageFetcher
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService { func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
repos: repos, accountRepo: accountRepo,
oauthService: oauthService, usageLogRepo: usageLogRepo,
httpClient: &http.Client{ usageFetcher: usageFetcher,
Timeout: 30 * time.Second,
},
} }
} }
@@ -88,7 +88,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS
// Setup Token账号: 根据session_window推算5h窗口7d数据不可用没有profile scope // Setup Token账号: 根据session_window推算5h窗口7d数据不可用没有profile scope
// API Key账号: 不支持usage查询 // API Key账号: 不支持usage查询
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
account, err := s.repos.Account.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account failed: %w", err) return nil, fmt.Errorf("get account failed: %w", err)
} }
@@ -97,8 +97,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
if account.CanGetUsage() { if account.CanGetUsage() {
// 检查缓存 // 检查缓存
if cached, ok := usageCacheMap.Load(accountID); ok { if cached, ok := usageCacheMap.Load(accountID); ok {
cache := cached.(*usageCache) cache, ok := cached.(*usageCache)
if time.Since(cache.timestamp) < cacheTTL { if !ok {
usageCacheMap.Delete(accountID)
} else if time.Since(cache.timestamp) < cacheTTL {
return cache.data, nil return cache.data, nil
} }
} }
@@ -148,7 +150,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
startTime = time.Now().Add(-5 * time.Hour) startTime = time.Now().Add(-5 * time.Hour)
} }
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime) stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil { if err != nil {
log.Printf("Failed to get window stats for account %d: %v", account.ID, err) log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
return return
@@ -163,7 +165,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
// GetTodayStats 获取账号今日统计 // GetTodayStats 获取账号今日统计
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) { func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID) stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get today stats failed: %w", err) return nil, fmt.Errorf("get today stats failed: %w", err)
} }
@@ -175,60 +177,33 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}, nil }, nil
} }
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get account usage stats failed: %w", err)
}
return stats, nil
}
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 // fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) { func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
// 获取access token从credentials中获取
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return nil, fmt.Errorf("no access token available") return nil, fmt.Errorf("no access token available")
} }
// 获取代理配置 var proxyURL string
transport := http.DefaultTransport.(*http.Transport).Clone()
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL() proxyURL = account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
} }
client := &http.Client{ usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
Transport: transport,
Timeout: 30 * time.Second,
}
// 构建请求
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("create request failed: %w", err) return nil, err
} }
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
// 解析响应
var usageResp ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
// 转换为UsageInfo
now := time.Now() now := time.Now()
return s.buildUsageInfo(&usageResp, &now), nil return s.buildUsageInfo(usageResp, &now), nil
} }
// parseTime 尝试多种格式解析时间 // parseTime 尝试多种格式解析时间

View File

@@ -2,20 +2,15 @@ package service
import ( import (
"context" "context"
"crypto/tls"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "log"
"net"
"net/http"
"net/url"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"golang.org/x/net/proxy"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -27,9 +22,9 @@ type AdminService interface {
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
@@ -50,6 +45,7 @@ type AdminService interface {
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error) ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
@@ -76,6 +72,9 @@ type AdminService interface {
type CreateUserInput struct { type CreateUserInput struct {
Email string Email string
Password string Password string
Username string
Wechat string
Notes string
Balance float64 Balance float64
Concurrency int Concurrency int
AllowedGroups []int64 AllowedGroups []int64
@@ -84,6 +83,9 @@ type CreateUserInput struct {
type UpdateUserInput struct { type UpdateUserInput struct {
Email string Email string
Password string Password string
Username *string
Wechat *string
Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0" Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string Status string
@@ -119,8 +121,8 @@ type CreateAccountInput struct {
Name string Name string
Platform string Platform string
Type string Type string
Credentials map[string]interface{} Credentials map[string]any
Extra map[string]interface{} Extra map[string]any
ProxyID *int64 ProxyID *int64
Concurrency int Concurrency int
Priority int Priority int
@@ -130,8 +132,8 @@ type CreateAccountInput struct {
type UpdateAccountInput struct { type UpdateAccountInput struct {
Name string Name string
Type string // Account type: oauth, setup-token, apikey Type string // Account type: oauth, setup-token, apikey
Credentials map[string]interface{} Credentials map[string]any
Extra map[string]interface{} Extra map[string]any
ProxyID *int64 ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0"
@@ -139,6 +141,33 @@ type UpdateAccountInput struct {
GroupIDs *[]int64 GroupIDs *[]int64
} }
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
Status string
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
}
// BulkUpdateAccountResult captures the result for a single account update.
type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []BulkUpdateAccountResult `json:"results"`
}
type CreateProxyInput struct { type CreateProxyInput struct {
Name string Name string
Protocol string Protocol string
@@ -177,44 +206,57 @@ type ProxyTestResult struct {
Country string `json:"country,omitempty"` Country string `json:"country,omitempty"`
} }
// ProxyExitInfo represents proxy exit information from ipinfo.io
type ProxyExitInfo struct {
IP string
City string
Region string
Country string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
// adminServiceImpl implements AdminService // adminServiceImpl implements AdminService
type adminServiceImpl struct { type adminServiceImpl struct {
userRepo *repository.UserRepository userRepo ports.UserRepository
groupRepo *repository.GroupRepository groupRepo ports.GroupRepository
accountRepo *repository.AccountRepository accountRepo ports.AccountRepository
proxyRepo *repository.ProxyRepository proxyRepo ports.ProxyRepository
apiKeyRepo *repository.ApiKeyRepository apiKeyRepo ports.ApiKeyRepository
redeemCodeRepo *repository.RedeemCodeRepository redeemCodeRepo ports.RedeemCodeRepository
usageLogRepo *repository.UsageLogRepository
userSubRepo *repository.UserSubscriptionRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
} }
// NewAdminService creates a new AdminService // NewAdminService creates a new AdminService
func NewAdminService(repos *repository.Repositories) AdminService { func NewAdminService(
userRepo ports.UserRepository,
groupRepo ports.GroupRepository,
accountRepo ports.AccountRepository,
proxyRepo ports.ProxyRepository,
apiKeyRepo ports.ApiKeyRepository,
redeemCodeRepo ports.RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
) AdminService {
return &adminServiceImpl{ return &adminServiceImpl{
userRepo: repos.User, userRepo: userRepo,
groupRepo: repos.Group, groupRepo: groupRepo,
accountRepo: repos.Account, accountRepo: accountRepo,
proxyRepo: repos.Proxy, proxyRepo: proxyRepo,
apiKeyRepo: repos.ApiKey, apiKeyRepo: apiKeyRepo,
redeemCodeRepo: repos.RedeemCode, redeemCodeRepo: redeemCodeRepo,
usageLogRepo: repos.UsageLog, billingCacheService: billingCacheService,
userSubRepo: repos.UserSubscription, proxyProber: proxyProber,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
// 注意AdminService是接口需要类型断言
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
if impl, ok := adminService.(*adminServiceImpl); ok {
impl.billingCacheService = billingCacheService
} }
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search) users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -229,6 +271,9 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User,
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) { func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
user := &model.User{ user := &model.User{
Email: input.Email, Email: input.Email,
Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes,
Role: "user", // Always create as regular user, never admin Role: "user", // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
@@ -254,8 +299,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
return nil, errors.New("cannot disable admin user") return nil, errors.New("cannot disable admin user")
} }
// Track balance and concurrency changes for logging
oldBalance := user.Balance
oldConcurrency := user.Concurrency oldConcurrency := user.Concurrency
if input.Email != "" { if input.Email != "" {
@@ -266,22 +309,25 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
return nil, err return nil, err
} }
} }
// Role is not allowed to be changed via API to prevent privilege escalation
if input.Username != nil {
user.Username = *input.Username
}
if input.Wechat != nil {
user.Wechat = *input.Wechat
}
if input.Notes != nil {
user.Notes = *input.Notes
}
if input.Status != "" { if input.Status != "" {
user.Status = input.Status user.Status = input.Status
} }
// 只在指针非 nil 时更新 Balance支持设置为 0
if input.Balance != nil {
user.Balance = *input.Balance
}
// 只在指针非 nil 时更新 Concurrency支持设置为任意值
if input.Concurrency != nil { if input.Concurrency != nil {
user.Concurrency = *input.Concurrency user.Concurrency = *input.Concurrency
} }
// 只在指针非 nil 时更新 AllowedGroups
if input.AllowedGroups != nil { if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups user.AllowedGroups = *input.AllowedGroups
} }
@@ -290,39 +336,15 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
return nil, err return nil, err
} }
// 余额变化时失效缓存
if input.Balance != nil && *input.Balance != oldBalance {
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
}()
}
}
// Create adjustment records for balance/concurrency changes
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(),
Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update
// The user update has already succeeded
}
}
concurrencyDiff := user.Concurrency - oldConcurrency concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 { if concurrencyDiff != 0 {
code, err := model.GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{ adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(), Code: code,
Type: model.AdjustmentTypeAdminConcurrency, Type: model.AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff), Value: float64(concurrencyDiff),
Status: model.StatusUsed, Status: model.StatusUsed,
@@ -331,8 +353,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
now := time.Now() now := time.Now()
adjustmentRecord.UsedAt = &now adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update log.Printf("failed to create concurrency adjustment redeem code: %v", err)
// The user update has already succeeded
} }
} }
@@ -351,12 +372,14 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return s.userRepo.Delete(ctx, id) return s.userRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) { func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
oldBalance := user.Balance
switch operation { switch operation {
case "set": case "set":
user.Balance = balance user.Balance = balance
@@ -366,24 +389,53 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
user.Balance -= balance user.Balance -= balance
} }
if user.Balance < 0 {
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
}
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err return nil, err
} }
// 失效余额缓存
if s.billingCacheService != nil { if s.billingCacheService != nil {
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
}
}() }()
} }
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := model.GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
Code: code,
Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
UsedBy: &user.ID,
Notes: notes,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
log.Printf("failed to create balance adjustment redeem code: %v", err)
}
}
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -391,9 +443,9 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return keys, result.Total, nil return keys, result.Total, nil
} }
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) { func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
// Return mock data for now // Return mock data for now
return map[string]interface{}{ return map[string]any{
"period": period, "period": period,
"total_requests": 0, "total_requests": 0,
"total_cost": 0.0, "total_cost": 0.0,
@@ -404,7 +456,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -566,7 +618,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
for _, userID := range affectedUserIDs { for _, userID := range affectedUserIDs {
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
}
} }
}() }()
} }
@@ -575,7 +629,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
} }
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) { func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -585,7 +639,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -633,10 +687,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Type != "" { if input.Type != "" {
account.Type = input.Type account.Type = input.Type
} }
if input.Credentials != nil && len(input.Credentials) > 0 { if len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials) account.Credentials = model.JSONB(input.Credentials)
} }
if input.Extra != nil && len(input.Extra) > 0 { if len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra) account.Extra = model.JSONB(input.Extra)
} }
if input.ProxyID != nil { if input.ProxyID != nil {
@@ -668,6 +722,65 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return account, nil return account, nil
} }
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
result := &BulkUpdateAccountsResult{
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
}
if len(input.AccountIDs) == 0 {
return result, nil
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := ports.AccountBulkUpdate{
Credentials: input.Credentials,
Extra: input.Extra,
}
if input.Name != "" {
repoUpdates.Name = &input.Name
}
if input.ProxyID != nil {
repoUpdates.ProxyID = input.ProxyID
}
if input.Concurrency != nil {
repoUpdates.Concurrency = input.Concurrency
}
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
// Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
return nil, err
}
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
entry.Success = true
result.Success++
result.Results = append(result.Results, entry)
}
return result, nil
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id) return s.accountRepo.Delete(ctx, id)
} }
@@ -703,7 +816,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
// Proxy management implementations // Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) { func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -788,7 +901,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
// Redeem code management implementations // Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) { func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@@ -818,8 +931,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
codes := make([]model.RedeemCode, 0, input.Count) codes := make([]model.RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ { for i := 0; i < input.Count; i++ {
codeValue, err := model.GenerateRedeemCode()
if err != nil {
return nil, err
}
code := model.RedeemCode{ code := model.RedeemCode{
Code: model.GenerateRedeemCode(), Code: codeValue,
Type: input.Type, Type: input.Type,
Value: input.Value, Value: input.Value,
Status: model.StatusUnused, Status: model.StatusUnused,
@@ -872,79 +989,12 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
return nil, err return nil, err
} }
return testProxyConnection(ctx, proxy)
}
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
proxyURL := proxy.URL() proxyURL := proxy.URL()
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
// Create HTTP client with proxy
transport, err := createProxyTransport(proxyURL)
if err != nil { if err != nil {
return &ProxyTestResult{ return &ProxyTestResult{
Success: false, Success: false,
Message: fmt.Sprintf("Failed to create proxy transport: %v", err), Message: err.Error(),
}, nil
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
}
// Measure latency
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Failed to create request: %v", err),
}, nil
}
resp, err := client.Do(req)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Proxy connection failed: %v", err),
}, nil
}
defer resp.Body.Close()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
LatencyMs: latencyMs,
}, nil
}
// Parse ipinfo.io response
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to read response",
LatencyMs: latencyMs,
}, nil
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to parse response",
LatencyMs: latencyMs,
}, nil }, nil
} }
@@ -952,38 +1002,9 @@ func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestRes
Success: true, Success: true,
Message: "Proxy is accessible", Message: "Proxy is accessible",
LatencyMs: latencyMs, LatencyMs: latencyMs,
IPAddress: ipInfo.IP, IPAddress: exitInfo.IP,
City: ipInfo.City, City: exitInfo.City,
Region: ipInfo.Region, Region: exitInfo.Region,
Country: ipInfo.Country, Country: exitInfo.Country,
}, nil }, nil
} }
// createProxyTransport creates an HTTP transport with the given proxy URL
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -6,10 +6,11 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@@ -17,18 +18,16 @@ import (
) )
var ( var (
ErrApiKeyNotFound = errors.New("api key not found") ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group") ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists") ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters") ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens") ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later") ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
) )
const ( const (
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:" apiKeyMaxErrorsPerHour = 20
apiKeyMaxErrorsPerHour = 20
apiKeyRateLimitDuration = time.Hour
) )
// CreateApiKeyRequest 创建API Key请求 // CreateApiKeyRequest 创建API Key请求
@@ -47,21 +46,21 @@ type UpdateApiKeyRequest struct {
// ApiKeyService API Key服务 // ApiKeyService API Key服务
type ApiKeyService struct { type ApiKeyService struct {
apiKeyRepo *repository.ApiKeyRepository apiKeyRepo ports.ApiKeyRepository
userRepo *repository.UserRepository userRepo ports.UserRepository
groupRepo *repository.GroupRepository groupRepo ports.GroupRepository
userSubRepo *repository.UserSubscriptionRepository userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client cache ports.ApiKeyCache
cfg *config.Config cfg *config.Config
} }
// NewApiKeyService 创建API Key服务实例 // NewApiKeyService 创建API Key服务实例
func NewApiKeyService( func NewApiKeyService(
apiKeyRepo *repository.ApiKeyRepository, apiKeyRepo ports.ApiKeyRepository,
userRepo *repository.UserRepository, userRepo ports.UserRepository,
groupRepo *repository.GroupRepository, groupRepo ports.GroupRepository,
userSubRepo *repository.UserSubscriptionRepository, userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client, cache ports.ApiKeyCache,
cfg *config.Config, cfg *config.Config,
) *ApiKeyService { ) *ApiKeyService {
return &ApiKeyService{ return &ApiKeyService{
@@ -69,7 +68,7 @@ func NewApiKeyService(
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
rdb: rdb, cache: cache,
cfg: cfg, cfg: cfg,
} }
} }
@@ -101,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查字符:只允许字母、数字、下划线、连字符 // 检查字符:只允许字母、数字、下划线、连字符
for _, c := range key { for _, c := range key {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || if (c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') || c == '_' || c == '-') { (c >= 'A' && c <= 'Z') ||
return ErrApiKeyInvalidChars (c >= '0' && c <= '9') ||
c == '_' || c == '-' {
continue
} }
return ErrApiKeyInvalidChars
} }
return nil return nil
@@ -112,13 +114,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限 // checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error { func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil { if s.cache == nil {
return nil return nil
} }
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) count, err := s.cache.GetCreateAttemptCount(ctx, userID)
count, err := s.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作 // Redis 出错时不阻止用户操作
return nil return nil
@@ -133,16 +133,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数 // incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) { func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil { if s.cache == nil {
return return
} }
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) _ = s.cache.IncrementCreateAttemptCount(ctx, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, _ = pipe.Exec(ctx)
} }
// canUserBindGroup 检查用户是否可以绑定指定分组 // canUserBindGroup 检查用户是否可以绑定指定分组
@@ -237,7 +232,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// List 获取用户的API Key列表 // List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) { func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err) return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -272,7 +267,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
} }
// 缓存到Redis可选TTL设置为5分钟 // 缓存到Redis可选TTL设置为5分钟
if s.rdb != nil { if s.cache != nil {
// 这里可以序列化并缓存API Key // 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误 _ = cacheKey // 使用变量避免未使用错误
} }
@@ -325,9 +320,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
if req.Status != nil { if req.Status != nil {
apiKey.Status = *req.Status apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存 // 如果状态改变清除Redis缓存
if s.rdb != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key) _ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
_ = s.rdb.Del(ctx, cacheKey)
} }
} }
@@ -354,9 +348,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
} }
// 清除Redis缓存 // 清除Redis缓存
if s.rdb != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key) _ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
_ = s.rdb.Del(ctx, cacheKey)
} }
if err := s.apiKeyRepo.Delete(ctx, id); err != nil { if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
@@ -399,13 +392,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// IncrementUsage 增加API Key使用次数可选用于统计 // IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器 // 使用Redis计数器
if s.rdb != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02")) cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil { if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
return fmt.Errorf("increment usage: %w", err) return fmt.Errorf("increment usage: %w", err)
} }
// 设置24小时过期 // 设置24小时过期
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour) _ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
} }
return nil return nil
} }
@@ -462,3 +455,11 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
// 标准类型分组:使用原有逻辑 // 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive) return user.CanBindGroup(group.ID, group.IsExclusive)
} }
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}
return keys, nil
}

Some files were not shown because too many files have changed in this diff Show More