Compare commits

..

53 Commits

Author SHA1 Message Date
刀刀
9d30ceae8d CC 400 返回具体错误信息 && 非 CC 请求时增加 system prompt (#26)
* feat: http 400 返回具体错误

* 更新 workflows

* 优化打包/docker 构建流程

* 400 是返回 原始错误 - json 格式

* feat: 非 cc请求时补充 system

* go mod tidy
2025-12-25 14:47:19 +08:00
IanShaw
60f6ed6bf6 feat: CRS 同步增强 - 自动刷新 OAuth token 和修复测试配置 (#27)
* fix(service): 修复 OpenAI Responses API 测试负载配置

- 所有账号类型统一添加 instructions 字段(不再仅限 OAuth)
- Responses API 要求所有请求必须包含 instructions 参数

* feat(crs-sync): CRS 同步时自动刷新 OAuth token 并保留完整 extra 字段

**核心功能**:
- CRSSyncService 注入 OAuth 服务依赖(Anthropic + OpenAI)
- 账号创建/更新后自动刷新 OAuth token,确保可用性
- 完整保留 CRS extra 字段,避免数据丢失

**Extra 字段增强**:
- 保留 CRS 所有原始 extra 字段
- 新增同步元数据: crs_account_id, crs_kind, crs_synced_at
- Claude 账号: 从 credentials 提取 org_uuid/account_uuid 到 extra
- OpenAI 账号: 映射 crs_email -> email

**Token 刷新逻辑**:
- 新增 refreshOAuthToken() 方法处理 Anthropic/OpenAI 平台
- 保留原有 credentials 字段,仅更新 token 相关字段
- 刷新失败静默处理,不中断同步流程

**依赖注入**:
- wire_gen.go: CRSSyncService 新增 oAuthService/openaiOAuthService

* style(crs-sync): 使用 switch 替代 if-else 修复 golangci-lint 警告

- 将 refreshOAuthToken 中的 if-else 改为 switch 语句
- 符合 staticcheck 规范
- 添加 default 分支处理未知平台
2025-12-25 14:45:17 +08:00
shaw
4a2f7d4a99 chore: CRS迁移功能增加版本提示 2025-12-25 10:57:04 +08:00
shaw
c19a393be9 Merge PR #24: feat: 添加账户同步与批量编辑功能
- 添加从 CRS 同步账户功能 (Claude OAuth/API Key, OpenAI OAuth/Responses)
- 添加批量编辑账户功能,支持 JSONB 字段智能合并
- 新增 CRSSyncService、BulkUpdate 仓储方法
- 前端新增 SyncFromCrsModal 和 BulkEditAccountModal 组件
2025-12-25 10:44:40 +08:00
ianshaw
938ffb002e style(frontend): format code with prettier
格式化前端业务代码,符合代码规范
- 统一代码风格
- 修复 ESLint 警告
2025-12-24 18:07:58 -08:00
ianshaw
372a01290b fix(backend): handle defer Close() errors in crs_sync_service
修复 golangci-lint 错误检查问题
- 使用匿名函数包装 defer Close() 并忽略错误
- 符合 Go 最佳实践
2025-12-24 17:58:47 -08:00
ianshaw
8b163ca49b chore: trigger CI after enabling Actions 2025-12-24 17:56:55 -08:00
ianshaw
d23810dc53 chore: trigger CI workflow 2025-12-24 17:54:43 -08:00
ianshaw
62ed5422dd feat(account): 优化批量更新实现,使用统一 SQL 合并 JSONB 字段
- 新增 BulkUpdate 仓储方法,使用单条 SQL 更新所有账户
- credentials/extra 使用 COALESCE(...) || ? 合并,只更新传入的 key
- name/proxy_id/concurrency/priority/status 只在提供时更新
- 分组绑定仍逐账号处理(需要独立操作)
- 前端优化:Base URL 留空则不修改,按勾选字段更新
- 完善 i18n 文案:说明留空不修改、批量更新行为
2025-12-24 17:16:19 -08:00
ianshaw
2e76302af7 feat(account): 添加批量编辑账户凭据功能并优化 CRS 同步
- 新增批量更新账户凭据接口(account_uuid/org_uuid/intercept_warmup_requests)
- 新增前端批量编辑模态框组件
- 优化 CRS 同步逻辑,改进 extra 字段处理
- 优化 CRS 同步 UI,添加更详细的结果展示
- 完善国际化文案(中英文)
2025-12-24 16:56:48 -08:00
ianshaw
6553828008 feat(account): 添加从 CRS 同步账户功能
- 添加账户同步 API 接口 (account_handler.go)
- 实现 CRS 同步服务 (crs_sync_service.go)
- 添加前端同步对话框组件 (SyncFromCrsModal.vue)
- 更新账户管理界面支持同步操作
- 添加账户仓库批量创建方法
- 添加中英文国际化翻译
- 更新依赖注入配置
2025-12-24 08:48:58 -08:00
ianshaw
adcb7bf00e chore: 更新 .gitignore 忽略配置文件并还原 Makefile
- 添加 backend/config.yaml 到 .gitignore(包含敏感信息)
- 添加 deploy/config.yaml 到 .gitignore(包含敏感信息)
- 添加 backend/.installed 到 .gitignore
- 还原 Makefile 到原始版本
2025-12-24 08:48:49 -08:00
shaw
876e85e7ad Merge branch 'feat/rename-go-module' 2025-12-24 21:34:37 +08:00
shaw
2e7818d688 feat(settings): 添加文档链接配置功能
- 后台系统设置新增文档链接(doc_url)配置项
- 首页顶部导航栏显示文档链接图标(条件渲染)
- Footer区域添加文档链接和GitHub链接
- 支持中英文国际化
2025-12-24 21:30:19 +08:00
Forest
836c4dda2b refactor: 重命名 go module 2025-12-24 21:07:21 +08:00
shaw
e65e9587b4 fix(concurrency): 重构并发管理使用独立Key+原生TTL
问题:旧方案使用计数器模式,每次acquire都刷新TTL,导致僵尸数据永不过期

解决方案:
- 每个槽位使用独立Redis Key: concurrency:account:{id}:{requestID}
- 利用Redis原生TTL,每个槽位独立5分钟过期
- 服务崩溃后僵尸数据自动清理,无需手动干预
- 兼容多实例K8s部署

技术改动:
- 新增SCAN脚本统计活跃槽位数量
- 移除冗余的releaseScript,直接使用DEL命令
- Wait队列TTL只在首次创建时设置,避免刷新
2025-12-24 21:00:29 +08:00
shaw
aaadd6ed04 fix(dashboard): 修复性能指标 RPM/TPM 显示为0的问题
- 修复 Admin Dashboard Handler 遗漏返回 rpm/tpm 字段
- 将性能统计时间窗口从1分钟改为5分钟平均值,数据更稳定
2025-12-24 19:58:33 +08:00
shaw
870b21916c feat(install): 添加安装指定版本和回退功能
- 新增 rollback 命令支持回退到指定版本
- 新增 list-versions 命令列出可用版本
- 新增 -v/--version 参数指定安装版本
- upgrade 命令支持升级到指定版本
- 添加安装状态检查,未安装时给出明确提示
- 版本切换仅替换二进制文件,保留配置和数据
- 自动备份当前版本(带版本号或时间戳后缀)
- 改进网络错误处理,添加超时和友好提示
- 修复 grep -oP 兼容性问题,改用 grep -oE
2025-12-24 17:44:13 +08:00
shaw
fb119f9a67 fix(version): 优化服务重启后页面刷新时机
- 将重启后等待时间从 3 秒增加到 8 秒
- 添加倒计时显示,提升用户体验
- 倒计时结束后先检测服务健康状态再刷新页面
- 避免刷新过早导致 502 错误
2025-12-24 17:21:17 +08:00
shaw
ad54795a24 feat(gateway): 添加上游错误重试机制
- OAuth/Setup Token 账号遇到 403 错误时,等待 2 秒后重试,最多 3 次
- Console 账号遇到未配置的错误码时,同样进行重试
- 重试耗尽后:OAuth 403 标记账号异常,Console 未配置错误码不标记账号
- 移除 handleErrorResponse 中已被重试逻辑覆盖的死代码
2025-12-24 16:55:46 +08:00
shaw
0abe322cca feat(accounts): 账户列表显示实时并发数
- 在账户列表 API 返回中添加 current_concurrency 字段
- 合并平台和类型列为 PlatformTypeBadge 组件,节省表格空间
- 新增并发状态列,显示 当前/最大 并发数,支持颜色编码
2025-12-24 15:44:45 +08:00
shaw
b071511676 refactor(accounts): 优化用量窗口显示,统一 OAuth 和 Setup Token 处理
- Setup Token 账号现在也调用 API 获取 5h 窗口用量数据
- 重新设计 UsageProgressBar UI,将用量统计移到进度条上方
- 删除冗余的 SetupTokenTimeWindow 组件
- 请求数/Token数支持 K/M/B 单位显示
2025-12-24 10:57:40 +08:00
shaw
7d9a757a26 feat(dashboard): 添加 RPM/TPM 性能指标
在 Dashboard 中用 RPM/TPM 卡片替换原来的"今日缓存"卡片,
实时显示最近1分钟的请求数和 Token 吞吐量。
2025-12-24 10:24:02 +08:00
Forest
bbf4024dc7 refactor(usage): 移动 usage 查询到 services 2025-12-24 08:41:31 +08:00
shaw
5831eb8a6a fix: 修复Claude OAuth token交换时authorization code解析错误
原代码中 `parts` 变量被创建但从未使用,导致 `len(parts) == 0`
永远为 true,使得即使成功从 `code#state` 格式中分割出 authCode,
最后也会被覆盖为原始的完整字符串。

这导致传递给Claude Token端点的code包含了 `#state` 部分,
Claude返回 "Invalid 'code' in request" 错误。
2025-12-23 19:42:52 +08:00
shaw
61838cdb3d fix: 兼容GLM等API的usage数据解析
部分第三方API(如GLM)的SSE响应格式与标准Claude API不同:
- 标准Claude: input_tokens在message_start中
- GLM等API: 所有tokens都在message_delta中

现在从message_delta中也解析input_tokens和cache相关字段,
如果message_start中没有值则使用message_delta中的数据。
2025-12-23 19:42:52 +08:00
dexcoder6
50dba656fd feat: 添加用户余额充值/退款功能 (#17)
## 功能特性

### 前端
- 在用户列表操作列添加充值和退款按钮
- 实现充值/退款对话框,支持输入金额和备注
- 从编辑用户表单中移除余额字段,防止直接修改
- 添加余额不足验证,实时显示操作后余额
- 优化备注提示词,提供多种场景示例

### 后端
- 为 redeem_codes 表添加 notes 字段(迁移文件)
- 在 UpdateUserBalance 接口添加 notes 参数支持
- 添加余额验证:金额必须大于0,操作后余额不能为负
- UpdateUser 接口移除 balance 字段处理,防止误操作
- 完整的审计日志和缓存管理

## 安全保护

- 前端:余额不足时禁用提交按钮,实时提示
- 后端:双重验证(输入金额 > 0 + 结果余额 >= 0)
- 权限:仅管理员可访问(AdminAuth 中间件)
- 审计:所有操作记录到 redeem_codes 表

## 修改文件

后端:
- backend/migrations/004_add_redeem_code_notes.sql
- backend/internal/model/redeem_code.go
- backend/internal/service/admin_service.go
- backend/internal/handler/admin/user_handler.go

前端:
- frontend/src/views/admin/UsersView.vue
- frontend/src/api/admin/users.ts
- frontend/src/i18n/locales/zh.ts
- frontend/src/i18n/locales/en.ts

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-23 16:29:57 +08:00
shaw
0e2821456c chore: 忽略TypeScript增量编译缓存文件 2025-12-23 16:27:56 +08:00
shaw
f25ac3aff5 feat: OpenAI OAuth账号显示Codex使用量
从响应头提取x-codex-*使用量信息并保存到账号Extra字段,
前端账号列表展示5h/7d窗口的使用进度条。
2025-12-23 16:26:07 +08:00
shaw
f6341b7f2b chore: 将"代理管理"菜单更名为"IP管理" 2025-12-23 15:46:10 +08:00
shaw
4e257512b9 style: 统一平台和分组列的样式
- 账号页面平台列改为与分组页面一致的标签样式
- 订阅页面分组列改用 GroupBadge 组件展示
- 修正 OpenAI OAuth 类型描述文案
2025-12-23 15:40:22 +08:00
shaw
e53b34f321 Merge PR #15: feat: 增强用户管理功能,添加用户名、微信号和备注字段 2025-12-23 14:03:07 +08:00
shaw
12ddae0184 fix: 优化OpenAI模型定价查找的回退逻辑
当模型ID在model_pricing.json中找不到时,增加智能回退策略:
- gpt-5.2-codex → 回退到 gpt-5.2
- gpt-5.2-20251222 → 去掉日期后缀回退到 gpt-5.2
- 最终回退到 DefaultTestModel (gpt-5.1-codex)
2025-12-23 13:58:56 +08:00
shaw
7b9c3f165e feat: 账号管理新增使用统计功能
- 新增账号统计弹窗,展示30天使用数据
- 显示总费用、请求数、日均费用、日均请求等汇总指标
- 显示今日概览、最高费用日、最高请求日
- 包含费用与请求趋势图(双Y轴)
- 复用模型分布图组件展示模型使用分布
- 显示实际扣费和标准计费(标准计费以较淡颜色显示)
2025-12-23 13:42:33 +08:00
dexcoder6
0b8e84f942 feat: 增强用户管理功能,添加用户名、微信号和备注字段
- 新增User模型字段:username(用户名)、wechat(微信号)、notes(备注)
- 扩展用户搜索功能,支持通过用户名和微信号搜索
- 添加用户个人资料更新功能,用户可自行编辑用户名和微信号
- 管理员用户列表新增用户名、微信号、备注显示列
- 备注字段仅对管理员可见,增强数据安全性
- 完善中英文国际化翻译
- 修复国际化文件中重复属性的TypeScript错误

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-23 11:26:22 +08:00
shaw
d9e27df9af feat: 账号列表显示所属分组
- Account模型新增Groups虚拟字段
- 账号列表API预加载Group信息
- 账号管理页面新增分组列,使用GroupBadge展示
2025-12-23 11:20:02 +08:00
shaw
f0fabf89a1 feat: 用户列表显示订阅分组及剩余天数
- User模型新增Subscriptions关联
- 用户列表批量加载订阅信息避免N+1查询
- GroupBadge组件支持显示剩余天数(过期红色、<=3天红色、<=7天橙色)
- 用户管理页面新增订阅分组列
2025-12-23 11:03:10 +08:00
shaw
5bbfbcdae9 fix: 修复订阅窗口过期后进度条显示不正确的问题
问题:滑动窗口过期后(如昨天用满额度),前端仍显示历史数据(红色进度条100%、"即将重置")

解决:
- 后端返回数据前检查窗口是否过期,过期则清零展示数据
- 前端处理 window_start 为 null 的情况,显示"窗口未激活"
- 不影响实际的窗口激活逻辑,窗口仍从当天零点开始
2025-12-23 10:38:15 +08:00
shaw
eb55947ec4 fix: 修复golangci-lint检查问题
- 移除OpenAIGatewayHandler中未使用的userService字段
- 将账号类型判断的if-else链改为switch语句
2025-12-23 10:25:32 +08:00
shaw
5f7e5184eb feat: admin/subscriptions新增重置时间显示 2025-12-23 10:14:41 +08:00
shaw
008a111268 chore: 更新前端构建信息 2025-12-23 10:03:34 +08:00
shaw
fda753278c feat: 平台图标与计费修复
- fix(billing): 修复 OpenAI 兼容 API 缓存 token 重复计费问题
- fix(auth): 隐藏数据库错误详情,返回通用服务不可用错误
- feat(ui): 新增 PlatformIcon 组件,GroupBadge 支持平台颜色区分
- feat(ui): 账号管理新增重置状态按钮,重授权后自动清除错误
- feat(ui): 分组管理新增计费类型列,显示订阅限额信息
- ui: 首页 GPT 状态改为已支持
2025-12-23 10:01:58 +08:00
shaw
6c469b42ed feat: 新增支持codex转发 2025-12-22 22:58:31 +08:00
shaw
dacf3a2a6e fix: 去掉accept-encoding透传 2025-12-21 21:30:19 +08:00
shaw
e6add93ae3 fix(build): add -tags embed to ensure frontend is embedded
- Add -tags=embed flag to GoReleaser builds
- Add -tags embed flag to Dockerfile builds
- Fix Dockerfile COPY order to prevent frontend dist being overwritten
- Update README build instructions with embed tag explanation
2025-12-20 19:13:26 +08:00
NepetaLemon
b2273ec695 ci(backend): 修复 backend-ci 2025-12-20 16:52:38 +08:00
Forest
aa89777dda ci(backend): 调整 embed server 2025-12-20 16:44:25 +08:00
Forest
1e1f3c0c74 ci(backend): 添加 gofmt 配置 2025-12-20 16:19:40 +08:00
Forest
1fab9204eb ci(backend): 添加 unused 配置 2025-12-20 16:12:44 +08:00
Forest
dbd3e71637 ci(backend): 添加 staticcheck 配置 2025-12-20 16:01:24 +08:00
Forest
974f67211b ci(backend): 添加 ineffassign 配置 2025-12-20 15:58:08 +08:00
Forest
0338c83b90 ci(backend): 添加 errcheck 配置 2025-12-20 15:52:13 +08:00
NepetaLemon
c6b3de1199 ci(backend): 添加 github actions (#10)
## 变更内容

### CI/CD
- 添加 GitHub Actions 工作流(test + golangci-lint)
- 添加 golangci-lint 配置,启用 errcheck/govet/staticcheck/unused/depguard
- 通过 depguard 强制 service 层不能直接导入 repository

### 错误处理修复
- 修复 CSV 写入、SSE 流式输出、随机数生成等未处理的错误
- GenerateRedeemCode() 现在返回 error

### 资源泄露修复
- 统一使用 defer func() { _ = xxx.Close() }() 模式

### 代码清理
- 移除未使用的常量
- 简化 nil map 检查
- 统一代码格式
2025-12-20 02:29:52 -05:00
180 changed files with 11961 additions and 2064 deletions

38
.github/workflows/backend-ci.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
name: CI
on:
push:
pull_request:
permissions:
contents: read
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
check-latest: true
cache: true
- name: Run tests
working-directory: backend
run: go test ./...
golangci-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
check-latest: true
cache: true
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.7
args: --timeout=5m
working-directory: backend

View File

@@ -85,6 +85,19 @@ jobs:
go-version: '1.24'
cache-dependency-path: backend/go.sum
# Docker setup for GoReleaser
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Fetch tags with annotations
run: |
# 确保获取完整的 annotated tag 信息
@@ -117,87 +130,16 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
# ===========================================================================
# Docker Build and Push
# ===========================================================================
docker:
needs: [update-version, build-frontend]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Download VERSION artifact
uses: actions/download-artifact@v4
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
uses: actions/download-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
# Extract version from tag
- name: Extract version
id: version
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Version: $VERSION"
# Set up Docker Buildx for multi-platform builds
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Login to DockerHub
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Extract metadata for Docker
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: |
weishaw/sub2api
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=latest,enable={{is_default_branch}}
# Build and push Docker image
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
COMMIT=${{ github.sha }}
DATE=${{ github.event.head_commit.timestamp }}
cache-from: type=gha
cache-to: type=gha,mode=max
# Update DockerHub description (optional)
# Update DockerHub description
- name: Update DockerHub description
uses: peter-evans/dockerhub-description@v4
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: weishaw/sub2api
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
short-description: "Sub2API - AI API Gateway Platform"
readme-filepath: ./deploy/DOCKER.md

15
.gitignore vendored
View File

@@ -28,6 +28,7 @@ node_modules/
frontend/node_modules/
frontend/dist/
*.local
*.tsbuildinfo
# 日志
npm-debug.log*
@@ -81,15 +82,27 @@ build/
release/
# 后端嵌入的前端构建产物
# Keep a placeholder file so `//go:embed all:dist` always has a match in CI/lint,
# while still ignoring generated frontend build outputs.
backend/internal/web/dist/
!backend/internal/web/dist/
backend/internal/web/dist/*
!backend/internal/web/dist/.keep
# 后端运行时缓存数据
backend/data/
# ===================
# 本地配置文件(包含敏感信息)
# ===================
backend/config.yaml
deploy/config.yaml
backend/.installed
# ===================
# 其他
# ===================
tests
CLAUDE.md
.claude
scripts
scripts

View File

@@ -11,6 +11,8 @@ builds:
dir: backend
main: ./cmd/server
binary: sub2api
flags:
- -tags=embed
env:
- CGO_ENABLED=0
goos:
@@ -50,10 +52,58 @@ changelog:
# 禁用自动 changelog完全使用 tag 消息
disable: true
# Docker images
dockers:
- id: amd64
goos: linux
goarch: amd64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
- id: arm64
goos: linux
goarch: arm64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
# Docker manifests for multi-arch support
docker_manifests:
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
release:
github:
owner: Wei-Shaw
name: sub2api
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
name: "{{ .Env.GITHUB_REPO_NAME }}"
draft: false
prerelease: auto
name_template: "Sub2API {{.Version}}"
@@ -71,7 +121,7 @@ release:
**One-line install (Linux):**
```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash
curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
```
**Manual download:**
@@ -79,5 +129,5 @@ release:
## 📚 Documentation
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api)
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md)
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
- [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)

View File

@@ -40,14 +40,15 @@ WORKDIR /app/backend
COPY backend/go.mod backend/go.sum ./
RUN go mod download
# Copy frontend dist from previous stage
COPY --from=frontend-builder /app/frontend/../backend/internal/web/dist ./internal/web/dist
# Copy backend source
# Copy backend source first
COPY backend/ ./
# Build the binary (BuildType=release for CI builds)
# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
# Build the binary (BuildType=release for CI builds, embed frontend)
RUN CGO_ENABLED=0 GOOS=linux go build \
-tags embed \
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
-o /app/sub2api \
./cmd/server

40
Dockerfile.goreleaser Normal file
View File

@@ -0,0 +1,40 @@
# =============================================================================
# Sub2API Dockerfile for GoReleaser
# =============================================================================
# This Dockerfile is used by GoReleaser to build Docker images.
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
# Install runtime dependencies
RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
&& rm -rf /var/cache/apk/*
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
WORKDIR /app
# Copy pre-built binary from GoReleaser
COPY sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
USER sub2api
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
ENTRYPOINT ["/app/sub2api"]

View File

@@ -220,21 +220,21 @@ cd sub2api
cd frontend
npm install
npm run build
# Output will be in ../backend/internal/web/dist/
# 3. Copy frontend build to backend (for embedding)
cp -r dist ../backend/internal/web/
# 4. Build backend (requires frontend dist to be present)
# 3. Build backend with embedded frontend
cd ../backend
go build -o sub2api ./cmd/server
go build -tags embed -o sub2api ./cmd/server
# 5. Create configuration file
# 4. Create configuration file
cp ../deploy/config.example.yaml ./config.yaml
# 6. Edit configuration
# 5. Edit configuration
nano config.yaml
```
> **Note:** The `-tags embed` flag embeds the frontend into the binary. Without this flag, the binary will not serve the frontend UI.
**Key configuration in `config.yaml`:**
```yaml
@@ -265,7 +265,7 @@ default:
```
```bash
# 7. Run the application
# 6. Run the application
./sub2api
```

View File

@@ -220,21 +220,21 @@ cd sub2api
cd frontend
npm install
npm run build
# 构建产物输出到 ../backend/internal/web/dist/
# 3. 复制前端构建产物到后端(用于嵌入
cp -r dist ../backend/internal/web/
# 4. 编译后端(需要前端 dist 目录存在)
# 3. 编译后端(嵌入前端
cd ../backend
go build -o sub2api ./cmd/server
go build -tags embed -o sub2api ./cmd/server
# 5. 创建配置文件
# 4. 创建配置文件
cp ../deploy/config.example.yaml ./config.yaml
# 6. 编辑配置
# 5. 编辑配置
nano config.yaml
```
> **注意:** `-tags embed` 参数会将前端嵌入到二进制文件中。不使用此参数编译的程序将不包含前端界面。
**`config.yaml` 关键配置:**
```yaml
@@ -265,7 +265,7 @@ default:
```
```bash
# 7. 运行应用
# 6. 运行应用
./sub2api
```

594
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,594 @@
version: "2"
linters:
default: none
enable:
- depguard
- errcheck
- govet
- ineffassign
- staticcheck
- unused
settings:
depguard:
rules:
# Enforce: service must not depend on repository.
service-no-repository:
list-mode: original
files:
- "**/internal/service/**"
deny:
- pkg: sub2api/internal/repository
desc: "service must not import repository"
handler-no-repository:
list-mode: original
files:
- "**/internal/handler/**"
deny:
- pkg: sub2api/internal/repository
desc: "handler must not import repository"
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`.
# Such cases aren't reported by default.
# Default: false
check-blank: false
# To disable the errcheck built-in exclude list.
# See `-excludeonly` option in https://github.com/kisielk/errcheck#excluding-functions for details.
# Default: false
disable-default-exclusions: true
# List of functions to exclude from checking, where each entry is a single function to exclude.
# See https://github.com/kisielk/errcheck#excluding-functions for details.
exclude-functions:
- io/ioutil.ReadFile
- io.Copy(*bytes.Buffer)
- io.Copy(os.Stdout)
- fmt.Println
- fmt.Print
- fmt.Printf
- fmt.Fprint
- fmt.Fprintf
- fmt.Fprintln
# Display function signature instead of selector.
# Default: false
verbose: true
ineffassign:
# Check escaping variables of type error, may cause false positives.
# Default: false
check-escaping-errors: true
staticcheck:
# https://staticcheck.dev/docs/configuration/options/#dot_import_whitelist
# Default: ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"]
dot-import-whitelist:
- fmt
# https://staticcheck.dev/docs/configuration/options/#initialisms
# Default: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"]
initialisms: [ "ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS" ]
# https://staticcheck.dev/docs/configuration/options/#http_status_code_whitelist
# Default: ["200", "400", "404", "500"]
http-status-code-whitelist: [ "200", "400", "404", "500" ]
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
# Run `GL_DEBUG=staticcheck golangci-lint run --enable=staticcheck` to see all available checks and enabled by config checks.
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
checks:
# Invalid regular expression.
# https://staticcheck.dev/docs/checks/#SA1000
- SA1000
# Invalid template.
# https://staticcheck.dev/docs/checks/#SA1001
- SA1001
# Invalid format in 'time.Parse'.
# https://staticcheck.dev/docs/checks/#SA1002
- SA1002
# Unsupported argument to functions in 'encoding/binary'.
# https://staticcheck.dev/docs/checks/#SA1003
- SA1003
# Suspiciously small untyped constant in 'time.Sleep'.
# https://staticcheck.dev/docs/checks/#SA1004
- SA1004
# Invalid first argument to 'exec.Command'.
# https://staticcheck.dev/docs/checks/#SA1005
- SA1005
# 'Printf' with dynamic first argument and no further arguments.
# https://staticcheck.dev/docs/checks/#SA1006
- SA1006
# Invalid URL in 'net/url.Parse'.
# https://staticcheck.dev/docs/checks/#SA1007
- SA1007
# Non-canonical key in 'http.Header' map.
# https://staticcheck.dev/docs/checks/#SA1008
- SA1008
# '(*regexp.Regexp).FindAll' called with 'n == 0', which will always return zero results.
# https://staticcheck.dev/docs/checks/#SA1010
- SA1010
# Various methods in the "strings" package expect valid UTF-8, but invalid input is provided.
# https://staticcheck.dev/docs/checks/#SA1011
- SA1011
# A nil 'context.Context' is being passed to a function, consider using 'context.TODO' instead.
# https://staticcheck.dev/docs/checks/#SA1012
- SA1012
# 'io.Seeker.Seek' is being called with the whence constant as the first argument, but it should be the second.
# https://staticcheck.dev/docs/checks/#SA1013
- SA1013
# Non-pointer value passed to 'Unmarshal' or 'Decode'.
# https://staticcheck.dev/docs/checks/#SA1014
- SA1014
# Using 'time.Tick' in a way that will leak. Consider using 'time.NewTicker', and only use 'time.Tick' in tests, commands and endless functions.
# https://staticcheck.dev/docs/checks/#SA1015
- SA1015
# Trapping a signal that cannot be trapped.
# https://staticcheck.dev/docs/checks/#SA1016
- SA1016
# Channels used with 'os/signal.Notify' should be buffered.
# https://staticcheck.dev/docs/checks/#SA1017
- SA1017
# 'strings.Replace' called with 'n == 0', which does nothing.
# https://staticcheck.dev/docs/checks/#SA1018
- SA1018
# Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019
- SA1019
# Using an invalid host:port pair with a 'net.Listen'-related function.
# https://staticcheck.dev/docs/checks/#SA1020
- SA1020
# Using 'bytes.Equal' to compare two 'net.IP'.
# https://staticcheck.dev/docs/checks/#SA1021
- SA1021
# Modifying the buffer in an 'io.Writer' implementation.
# https://staticcheck.dev/docs/checks/#SA1023
- SA1023
# A string cutset contains duplicate characters.
# https://staticcheck.dev/docs/checks/#SA1024
- SA1024
# It is not possible to use '(*time.Timer).Reset''s return value correctly.
# https://staticcheck.dev/docs/checks/#SA1025
- SA1025
# Cannot marshal channels or functions.
# https://staticcheck.dev/docs/checks/#SA1026
- SA1026
# Atomic access to 64-bit variable must be 64-bit aligned.
# https://staticcheck.dev/docs/checks/#SA1027
- SA1027
# 'sort.Slice' can only be used on slices.
# https://staticcheck.dev/docs/checks/#SA1028
- SA1028
# Inappropriate key in call to 'context.WithValue'.
# https://staticcheck.dev/docs/checks/#SA1029
- SA1029
# Invalid argument in call to a 'strconv' function.
# https://staticcheck.dev/docs/checks/#SA1030
- SA1030
# Overlapping byte slices passed to an encoder.
# https://staticcheck.dev/docs/checks/#SA1031
- SA1031
# Wrong order of arguments to 'errors.Is'.
# https://staticcheck.dev/docs/checks/#SA1032
- SA1032
# 'sync.WaitGroup.Add' called inside the goroutine, leading to a race condition.
# https://staticcheck.dev/docs/checks/#SA2000
- SA2000
# Empty critical section, did you mean to defer the unlock?.
# https://staticcheck.dev/docs/checks/#SA2001
- SA2001
# Called 'testing.T.FailNow' or 'SkipNow' in a goroutine, which isn't allowed.
# https://staticcheck.dev/docs/checks/#SA2002
- SA2002
# Deferred 'Lock' right after locking, likely meant to defer 'Unlock' instead.
# https://staticcheck.dev/docs/checks/#SA2003
- SA2003
# 'TestMain' doesn't call 'os.Exit', hiding test failures.
# https://staticcheck.dev/docs/checks/#SA3000
- SA3000
# Assigning to 'b.N' in benchmarks distorts the results.
# https://staticcheck.dev/docs/checks/#SA3001
- SA3001
# Binary operator has identical expressions on both sides.
# https://staticcheck.dev/docs/checks/#SA4000
- SA4000
# '&*x' gets simplified to 'x', it does not copy 'x'.
# https://staticcheck.dev/docs/checks/#SA4001
- SA4001
# Comparing unsigned values against negative values is pointless.
# https://staticcheck.dev/docs/checks/#SA4003
- SA4003
# The loop exits unconditionally after one iteration.
# https://staticcheck.dev/docs/checks/#SA4004
- SA4004
# Field assignment that will never be observed. Did you mean to use a pointer receiver?.
# https://staticcheck.dev/docs/checks/#SA4005
- SA4005
# A value assigned to a variable is never read before being overwritten. Forgotten error check or dead code?.
# https://staticcheck.dev/docs/checks/#SA4006
- SA4006
# The variable in the loop condition never changes, are you incrementing the wrong variable?.
# https://staticcheck.dev/docs/checks/#SA4008
- SA4008
# A function argument is overwritten before its first use.
# https://staticcheck.dev/docs/checks/#SA4009
- SA4009
# The result of 'append' will never be observed anywhere.
# https://staticcheck.dev/docs/checks/#SA4010
- SA4010
# Break statement with no effect. Did you mean to break out of an outer loop?.
# https://staticcheck.dev/docs/checks/#SA4011
- SA4011
# Comparing a value against NaN even though no value is equal to NaN.
# https://staticcheck.dev/docs/checks/#SA4012
- SA4012
# Negating a boolean twice ('!!b') is the same as writing 'b'. This is either redundant, or a typo.
# https://staticcheck.dev/docs/checks/#SA4013
- SA4013
# An if/else if chain has repeated conditions and no side-effects; if the condition didn't match the first time, it won't match the second time, either.
# https://staticcheck.dev/docs/checks/#SA4014
- SA4014
# Calling functions like 'math.Ceil' on floats converted from integers doesn't do anything useful.
# https://staticcheck.dev/docs/checks/#SA4015
- SA4015
# Certain bitwise operations, such as 'x ^ 0', do not do anything useful.
# https://staticcheck.dev/docs/checks/#SA4016
- SA4016
# Discarding the return values of a function without side effects, making the call pointless.
# https://staticcheck.dev/docs/checks/#SA4017
- SA4017
# Self-assignment of variables.
# https://staticcheck.dev/docs/checks/#SA4018
- SA4018
# Multiple, identical build constraints in the same file.
# https://staticcheck.dev/docs/checks/#SA4019
- SA4019
# Unreachable case clause in a type switch.
# https://staticcheck.dev/docs/checks/#SA4020
- SA4020
# "x = append(y)" is equivalent to "x = y".
# https://staticcheck.dev/docs/checks/#SA4021
- SA4021
# Comparing the address of a variable against nil.
# https://staticcheck.dev/docs/checks/#SA4022
- SA4022
# Impossible comparison of interface value with untyped nil.
# https://staticcheck.dev/docs/checks/#SA4023
- SA4023
# Checking for impossible return value from a builtin function.
# https://staticcheck.dev/docs/checks/#SA4024
- SA4024
# Integer division of literals that results in zero.
# https://staticcheck.dev/docs/checks/#SA4025
- SA4025
# Go constants cannot express negative zero.
# https://staticcheck.dev/docs/checks/#SA4026
- SA4026
# '(*net/url.URL).Query' returns a copy, modifying it doesn't change the URL.
# https://staticcheck.dev/docs/checks/#SA4027
- SA4027
# 'x % 1' is always zero.
# https://staticcheck.dev/docs/checks/#SA4028
- SA4028
# Ineffective attempt at sorting slice.
# https://staticcheck.dev/docs/checks/#SA4029
- SA4029
# Ineffective attempt at generating random number.
# https://staticcheck.dev/docs/checks/#SA4030
- SA4030
# Checking never-nil value against nil.
# https://staticcheck.dev/docs/checks/#SA4031
- SA4031
# Comparing 'runtime.GOOS' or 'runtime.GOARCH' against impossible value.
# https://staticcheck.dev/docs/checks/#SA4032
- SA4032
# Assignment to nil map.
# https://staticcheck.dev/docs/checks/#SA5000
- SA5000
# Deferring 'Close' before checking for a possible error.
# https://staticcheck.dev/docs/checks/#SA5001
- SA5001
# The empty for loop ("for {}") spins and can block the scheduler.
# https://staticcheck.dev/docs/checks/#SA5002
- SA5002
# Defers in infinite loops will never execute.
# https://staticcheck.dev/docs/checks/#SA5003
- SA5003
# "for { select { ..." with an empty default branch spins.
# https://staticcheck.dev/docs/checks/#SA5004
- SA5004
# The finalizer references the finalized object, preventing garbage collection.
# https://staticcheck.dev/docs/checks/#SA5005
- SA5005
# Infinite recursive call.
# https://staticcheck.dev/docs/checks/#SA5007
- SA5007
# Invalid struct tag.
# https://staticcheck.dev/docs/checks/#SA5008
- SA5008
# Invalid Printf call.
# https://staticcheck.dev/docs/checks/#SA5009
- SA5009
# Impossible type assertion.
# https://staticcheck.dev/docs/checks/#SA5010
- SA5010
# Possible nil pointer dereference.
# https://staticcheck.dev/docs/checks/#SA5011
- SA5011
# Passing odd-sized slice to function expecting even size.
# https://staticcheck.dev/docs/checks/#SA5012
- SA5012
# Using 'regexp.Match' or related in a loop, should use 'regexp.Compile'.
# https://staticcheck.dev/docs/checks/#SA6000
- SA6000
# Missing an optimization opportunity when indexing maps by byte slices.
# https://staticcheck.dev/docs/checks/#SA6001
- SA6001
# Storing non-pointer values in 'sync.Pool' allocates memory.
# https://staticcheck.dev/docs/checks/#SA6002
- SA6002
# Converting a string to a slice of runes before ranging over it.
# https://staticcheck.dev/docs/checks/#SA6003
- SA6003
# Inefficient string comparison with 'strings.ToLower' or 'strings.ToUpper'.
# https://staticcheck.dev/docs/checks/#SA6005
- SA6005
# Using io.WriteString to write '[]byte'.
# https://staticcheck.dev/docs/checks/#SA6006
- SA6006
# Defers in range loops may not run when you expect them to.
# https://staticcheck.dev/docs/checks/#SA9001
- SA9001
# Using a non-octal 'os.FileMode' that looks like it was meant to be in octal.
# https://staticcheck.dev/docs/checks/#SA9002
- SA9002
# Empty body in an if or else branch.
# https://staticcheck.dev/docs/checks/#SA9003
- SA9003
# Only the first constant has an explicit type.
# https://staticcheck.dev/docs/checks/#SA9004
- SA9004
# Trying to marshal a struct with no public fields nor custom marshaling.
# https://staticcheck.dev/docs/checks/#SA9005
- SA9005
# Dubious bit shifting of a fixed size integer value.
# https://staticcheck.dev/docs/checks/#SA9006
- SA9006
# Deleting a directory that shouldn't be deleted.
# https://staticcheck.dev/docs/checks/#SA9007
- SA9007
# 'else' branch of a type assertion is probably not reading the right value.
# https://staticcheck.dev/docs/checks/#SA9008
- SA9008
# Ineffectual Go compiler directive.
# https://staticcheck.dev/docs/checks/#SA9009
- SA9009
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- ST1000
# Dot imports are discouraged.
# https://staticcheck.dev/docs/checks/#ST1001
- ST1001
# Poorly chosen identifier.
# https://staticcheck.dev/docs/checks/#ST1003
- ST1003
# Incorrectly formatted error string.
# https://staticcheck.dev/docs/checks/#ST1005
- ST1005
# Poorly chosen receiver name.
# https://staticcheck.dev/docs/checks/#ST1006
- ST1006
# A function's error value should be its last return value.
# https://staticcheck.dev/docs/checks/#ST1008
- ST1008
# Poorly chosen name for variable of type 'time.Duration'.
# https://staticcheck.dev/docs/checks/#ST1011
- ST1011
# Poorly chosen name for error variable.
# https://staticcheck.dev/docs/checks/#ST1012
- ST1012
# Should use constants for HTTP error codes, not magic numbers.
# https://staticcheck.dev/docs/checks/#ST1013
- ST1013
# A switch's default case should be the first or last case.
# https://staticcheck.dev/docs/checks/#ST1015
- ST1015
# Use consistent method receiver names.
# https://staticcheck.dev/docs/checks/#ST1016
- ST1016
# Don't use Yoda conditions.
# https://staticcheck.dev/docs/checks/#ST1017
- ST1017
# Avoid zero-width and control characters in string literals.
# https://staticcheck.dev/docs/checks/#ST1018
- ST1018
# Importing the same package multiple times.
# https://staticcheck.dev/docs/checks/#ST1019
- ST1019
# The documentation of an exported function should start with the function's name.
# https://staticcheck.dev/docs/checks/#ST1020
- ST1020
# The documentation of an exported type should start with type's name.
# https://staticcheck.dev/docs/checks/#ST1021
- ST1021
# The documentation of an exported variable or constant should start with variable's name.
# https://staticcheck.dev/docs/checks/#ST1022
- ST1022
# Redundant type in variable declaration.
# https://staticcheck.dev/docs/checks/#ST1023
- ST1023
# Use plain channel send or receive instead of single-case select.
# https://staticcheck.dev/docs/checks/#S1000
- S1000
# Replace for loop with call to copy.
# https://staticcheck.dev/docs/checks/#S1001
- S1001
# Omit comparison with boolean constant.
# https://staticcheck.dev/docs/checks/#S1002
- S1002
# Replace call to 'strings.Index' with 'strings.Contains'.
# https://staticcheck.dev/docs/checks/#S1003
- S1003
# Replace call to 'bytes.Compare' with 'bytes.Equal'.
# https://staticcheck.dev/docs/checks/#S1004
- S1004
# Drop unnecessary use of the blank identifier.
# https://staticcheck.dev/docs/checks/#S1005
- S1005
# Use "for { ... }" for infinite loops.
# https://staticcheck.dev/docs/checks/#S1006
- S1006
# Simplify regular expression by using raw string literal.
# https://staticcheck.dev/docs/checks/#S1007
- S1007
# Simplify returning boolean expression.
# https://staticcheck.dev/docs/checks/#S1008
- S1008
# Omit redundant nil check on slices, maps, and channels.
# https://staticcheck.dev/docs/checks/#S1009
- S1009
# Omit default slice index.
# https://staticcheck.dev/docs/checks/#S1010
- S1010
# Use a single 'append' to concatenate two slices.
# https://staticcheck.dev/docs/checks/#S1011
- S1011
# Replace 'time.Now().Sub(x)' with 'time.Since(x)'.
# https://staticcheck.dev/docs/checks/#S1012
- S1012
# Use a type conversion instead of manually copying struct fields.
# https://staticcheck.dev/docs/checks/#S1016
- S1016
# Replace manual trimming with 'strings.TrimPrefix'.
# https://staticcheck.dev/docs/checks/#S1017
- S1017
# Use "copy" for sliding elements.
# https://staticcheck.dev/docs/checks/#S1018
- S1018
# Simplify "make" call by omitting redundant arguments.
# https://staticcheck.dev/docs/checks/#S1019
- S1019
# Omit redundant nil check in type assertion.
# https://staticcheck.dev/docs/checks/#S1020
- S1020
# Merge variable declaration and assignment.
# https://staticcheck.dev/docs/checks/#S1021
- S1021
# Omit redundant control flow.
# https://staticcheck.dev/docs/checks/#S1023
- S1023
# Replace 'x.Sub(time.Now())' with 'time.Until(x)'.
# https://staticcheck.dev/docs/checks/#S1024
- S1024
# Don't use 'fmt.Sprintf("%s", x)' unnecessarily.
# https://staticcheck.dev/docs/checks/#S1025
- S1025
# Simplify error construction with 'fmt.Errorf'.
# https://staticcheck.dev/docs/checks/#S1028
- S1028
# Range over the string directly.
# https://staticcheck.dev/docs/checks/#S1029
- S1029
# Use 'bytes.Buffer.String' or 'bytes.Buffer.Bytes'.
# https://staticcheck.dev/docs/checks/#S1030
- S1030
# Omit redundant nil check around loop.
# https://staticcheck.dev/docs/checks/#S1031
- S1031
# Use 'sort.Ints(x)', 'sort.Float64s(x)', and 'sort.Strings(x)'.
# https://staticcheck.dev/docs/checks/#S1032
- S1032
# Unnecessary guard around call to "delete".
# https://staticcheck.dev/docs/checks/#S1033
- S1033
# Use result of type assertion to simplify cases.
# https://staticcheck.dev/docs/checks/#S1034
- S1034
# Redundant call to 'net/http.CanonicalHeaderKey' in method call on 'net/http.Header'.
# https://staticcheck.dev/docs/checks/#S1035
- S1035
# Unnecessary guard around map access.
# https://staticcheck.dev/docs/checks/#S1036
- S1036
# Elaborate way of sleeping.
# https://staticcheck.dev/docs/checks/#S1037
- S1037
# Unnecessarily complex way of printing formatted string.
# https://staticcheck.dev/docs/checks/#S1038
- S1038
# Unnecessary use of 'fmt.Sprint'.
# https://staticcheck.dev/docs/checks/#S1039
- S1039
# Type assertion to current type.
# https://staticcheck.dev/docs/checks/#S1040
- S1040
# Apply De Morgan's law.
# https://staticcheck.dev/docs/checks/#QF1001
- QF1001
# Convert untagged switch to tagged switch.
# https://staticcheck.dev/docs/checks/#QF1002
- QF1002
# Convert if/else-if chain to tagged switch.
# https://staticcheck.dev/docs/checks/#QF1003
- QF1003
# Use 'strings.ReplaceAll' instead of 'strings.Replace' with 'n == -1'.
# https://staticcheck.dev/docs/checks/#QF1004
- QF1004
# Expand call to 'math.Pow'.
# https://staticcheck.dev/docs/checks/#QF1005
- QF1005
# Lift 'if'+'break' into loop condition.
# https://staticcheck.dev/docs/checks/#QF1006
- QF1006
# Merge conditional assignment into variable declaration.
# https://staticcheck.dev/docs/checks/#QF1007
- QF1007
# Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008
- QF1008
# Use 'time.Time.Equal' instead of '==' operator.
# https://staticcheck.dev/docs/checks/#QF1009
- QF1009
# Convert slice of bytes to string when printing it.
# https://staticcheck.dev/docs/checks/#QF1010
- QF1010
# Omit redundant type from variable declaration.
# https://staticcheck.dev/docs/checks/#QF1011
- QF1011
# Use 'fmt.Fprintf(x, ...)' instead of 'x.Write(fmt.Sprintf(...))'.
# https://staticcheck.dev/docs/checks/#QF1012
- QF1012
unused:
# Mark all struct fields that have been written to as used.
# Default: true
field-writes-are-uses: false
# Treat IncDec statement (e.g. `i++` or `i--`) as both read and write operation instead of just write.
# Default: false
post-statements-are-reads: true
# Mark all exported fields as used.
# default: true
exported-fields-are-used: false
# Mark all function parameters as used.
# default: true
parameters-are-used: true
# Mark all local variables as used.
# default: true
local-variables-are-used: false
# Mark all identifiers inside generated files as used.
# Default: true
generated-is-used: false
formatters:
enable:
- gofmt
settings:
gofmt:
# Simplify code: gofmt with `-s` option.
# Default: true
simplify: false
# Apply the rewrite rules to the source before reformatting.
# https://pkg.go.dev/cmd/gofmt
# Default: []
rewrite-rules:
- pattern: 'interface{}'
replacement: 'any'
- pattern: 'a[b:len(a)]'
replacement: 'a[b:]'

View File

@@ -1,6 +1,16 @@
.PHONY: wire
.PHONY: wire build build-embed
wire:
@echo "生成 Wire 代码..."
@cd cmd/server && go generate
@echo "Wire 代码生成完成"
@echo "Wire 代码生成完成"
build:
@echo "构建后端(不嵌入前端)..."
@go build -o bin/server ./cmd/server
@echo "构建完成: bin/server"
build-embed:
@echo "构建后端(嵌入前端)..."
@go build -tags embed -o bin/server ./cmd/server
@echo "构建完成: bin/server (with embedded frontend)"

View File

@@ -15,11 +15,11 @@ import (
"syscall"
"time"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/setup"
"sub2api/internal/web"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
)

View File

@@ -4,12 +4,12 @@
package main
import (
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/service"
"context"
"log"
@@ -85,6 +85,14 @@ func provideCleanup(
services.EmailQueue.Stop()
return nil
}},
{"OAuthService", func() error {
services.OAuth.Stop()
return nil
}},
{"OpenAIOAuthService", func() error {
services.OpenAIOAuth.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},

View File

@@ -8,17 +8,17 @@ package main
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
"log"
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/handler/admin"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"time"
)
@@ -48,7 +48,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository, configConfig)
userService := service.NewUserService(userRepository)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)
@@ -58,7 +58,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
billingCache := repository.NewBillingCache(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
@@ -67,22 +67,29 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository)
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
proxyExitInfoProber := repository.NewProxyExitInfoProber()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
claudeUpstream := repository.NewClaudeUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
@@ -92,8 +99,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
@@ -103,43 +110,45 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
groupService := service.NewGroupService(groupRepository)
accountService := service.NewAccountService(accountRepository, groupRepository)
proxyService := service.NewProxyService(proxyRepository)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
services := &service.Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
Update: updateService,
TokenRefresh: tokenRefreshService,
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OpenAIGateway: openAIGatewayService,
OAuth: oAuthService,
OpenAIOAuth: openAIOAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
Update: updateService,
TokenRefresh: tokenRefreshService,
}
repositories := &repository.Repositories{
User: userRepository,
@@ -201,6 +210,14 @@ func provideCleanup(
services.EmailQueue.Stop()
return nil
}},
{"OAuthService", func() error {
services.OAuth.Stop()
return nil
}},
{"OpenAIOAuthService", func() error {
services.OpenAIOAuth.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},

View File

@@ -1,38 +0,0 @@
server:
host: "0.0.0.0"
port: 8080
mode: "debug" # debug/release
database:
host: "127.0.0.1"
port: 5432
user: "postgres"
password: "XZeRr7nkjHWhm8fw"
dbname: "sub2api"
sslmode: "disable"
redis:
host: "127.0.0.1"
port: 6379
password: ""
db: 0
jwt:
secret: "your-secret-key-change-in-production"
expire_hour: 24
default:
admin_email: "admin@sub2api.com"
admin_password: "admin123"
user_concurrency: 5
user_balance: 0
api_key_prefix: "sk-"
rate_multiplier: 1.0
# Timezone configuration (similar to PHP's date_default_timezone_set)
# This affects ALL time operations:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
timezone: "Asia/Shanghai"

View File

@@ -1,4 +1,4 @@
module sub2api
module github.com/Wei-Shaw/sub2api
go 1.24.0
@@ -8,10 +8,13 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/imroc/req/v3 v3.56.0
github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.3.0
github.com/spf13/viper v1.18.2
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0
@@ -35,7 +38,6 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/google/wire v0.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -64,6 +66,8 @@ require (
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
go.uber.org/atomic v1.9.0 // indirect

View File

@@ -139,6 +139,15 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=

View File

@@ -52,7 +52,7 @@ type PricingConfig struct {
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
}
@@ -163,7 +163,7 @@ func setDefaults() {
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
// Database
viper.SetDefault("database.host", "localhost")
@@ -210,10 +210,10 @@ func setDefaults() {
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
}
func (c *Config) Validate() error {

View File

@@ -3,9 +3,12 @@ package admin
import (
"strconv"
"sub2api/internal/pkg/claude"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -13,14 +16,12 @@ import (
// OAuthHandler handles OAuth-related operations for accounts
type OAuthHandler struct {
oauthService *service.OAuthService
adminService service.AdminService
}
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
func NewOAuthHandler(oauthService *service.OAuthService) *OAuthHandler {
return &OAuthHandler{
oauthService: oauthService,
adminService: adminService,
}
}
@@ -28,47 +29,81 @@ func NewOAuthHandler(oauthService *service.OAuthService, adminService service.Ad
type AccountHandler struct {
adminService service.AdminService
oauthService *service.OAuthService
openaiOAuthService *service.OpenAIOAuthService
rateLimitService *service.RateLimitService
accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
}
// NewAccountHandler creates a new admin account handler
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
func NewAccountHandler(
adminService service.AdminService,
oauthService *service.OAuthService,
openaiOAuthService *service.OpenAIOAuthService,
rateLimitService *service.RateLimitService,
accountUsageService *service.AccountUsageService,
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
rateLimitService: rateLimitService,
accountUsageService: accountUsageService,
accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
}
}
// CreateAccountRequest represents create account request
type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials" binding:"required"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
// UpdateAccountRequest represents update account request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct {
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
}
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*model.Account
CurrentConcurrency int `json:"current_concurrency"`
}
// List handles listing all accounts with pagination
@@ -86,7 +121,28 @@ func (h *AccountHandler) List(c *gin.Context) {
return
}
response.Paginated(c, accounts, total, page, pageSize)
// Get current concurrency counts for all accounts
accountIDs := make([]int64, len(accounts))
for i, acc := range accounts {
accountIDs[i] = acc.ID
}
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
if err != nil {
// Log error but don't fail the request, just use 0 for all
concurrencyCounts = make(map[int64]int)
}
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
result[i] = AccountWithConcurrency{
Account: &accounts[i],
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
}
}
response.Paginated(c, result, total, page, pageSize)
}
// GetByID handles getting an account by ID
@@ -192,6 +248,13 @@ type TestAccountRequest struct {
ModelID string `json:"model_id"`
}
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
}
// Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) {
@@ -212,6 +275,35 @@ func (h *AccountHandler) Test(c *gin.Context) {
}
}
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
// POST /api/v1/admin/accounts/sync/crs
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
var req SyncFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Default to syncing proxies (can be disabled by explicitly setting false)
syncProxies := true
if req.SyncProxies != nil {
syncProxies = *req.SyncProxies
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
})
if err != nil {
response.BadRequest(c, "Sync failed: "+err.Error())
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
@@ -234,26 +326,47 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return
}
// Use OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
var newCredentials map[string]any
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
newCredentials := make(map[string]interface{})
for k, v := range account.Credentials {
newCredentials[k] = v
}
if account.IsOpenAI() {
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
newCredentials["refresh_token"] = tokenInfo.RefreshToken
newCredentials["scope"] = tokenInfo.Scope
// Build new credentials from token info
newCredentials = h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
newCredentials = make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
newCredentials["refresh_token"] = tokenInfo.RefreshToken
newCredentials["scope"] = tokenInfo.Scope
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
@@ -275,15 +388,26 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
return
}
// Return mock data for now
_ = accountID
response.Success(c, gin.H{
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"total_tokens": 0,
"average_response_time": 0,
})
// Parse days parameter (default 30)
days := 30
if daysStr := c.Query("days"); daysStr != "" {
if d, err := strconv.Atoi(daysStr); err == nil && d > 0 && d <= 90 {
days = d
}
}
// Calculate time range
now := timezone.Now()
endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1))
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get account stats: "+err.Error())
return
}
response.Success(c, stats)
}
// ClearError handles clearing account error
@@ -323,6 +447,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
})
}
// BatchUpdateCredentialsRequest represents batch credentials update request
type BatchUpdateCredentialsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
Value any `json:"value"`
}
// BatchUpdateCredentials handles batch updating credentials fields
// POST /api/v1/admin/accounts/batch-update-credentials
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
var req BatchUpdateCredentialsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Validate value type based on field
if req.Field == "intercept_warmup_requests" {
// Must be boolean
if _, ok := req.Value.(bool); !ok {
response.BadRequest(c, "intercept_warmup_requests must be boolean")
return
}
} else {
// account_uuid and org_uuid can be string or null
if req.Value != nil {
if _, ok := req.Value.(string); !ok {
response.BadRequest(c, req.Field+" must be string or null")
return
}
}
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
})
}
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
// POST /api/v1/admin/accounts/bulk-update
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
var req BulkUpdateAccountsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.Status != "" ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
if !hasUpdates {
response.BadRequest(c, "No updates provided")
return
}
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
})
if err != nil {
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
return
}
response.Success(c, result)
}
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL
@@ -565,6 +819,46 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
// Handle OpenAI accounts
if account.IsOpenAI() {
// For OAuth accounts: return default OpenAI models
if account.IsOAuth() {
response.Success(c, openai.DefaultModels)
return
}
// For API Key accounts: check model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, openai.DefaultModels)
return
}
// Return mapped models
var models []openai.Model
for requestedModel := range mapping {
var found bool
for _, dm := range openai.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, openai.Model{
ID: requestedModel,
Object: "model",
Type: "model",
DisplayName: requestedModel,
})
}
}
response.Success(c, models)
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
response.Success(c, claude.DefaultModels)
@@ -573,7 +867,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
if len(mapping) == 0 {
// No mapping configured, return default models
response.Success(c, claude.DefaultModels)
return

View File

@@ -1,11 +1,10 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"time"
"github.com/gin-gonic/gin"
@@ -13,17 +12,15 @@ import (
// DashboardHandler handles admin dashboard statistics
type DashboardHandler struct {
adminService service.AdminService
usageRepo *repository.UsageLogRepository
startTime time.Time // Server start time for uptime calculation
dashboardService *service.DashboardService
startTime time.Time // Server start time for uptime calculation
}
// NewDashboardHandler creates a new admin dashboard handler
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
return &DashboardHandler{
adminService: adminService,
usageRepo: usageRepo,
startTime: time.Now(),
dashboardService: dashboardService,
startTime: time.Now(),
}
}
@@ -61,7 +58,7 @@ func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
// GetStats handles getting dashboard statistics
// GET /api/v1/admin/dashboard/stats
func (h *DashboardHandler) GetStats(c *gin.Context) {
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context())
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
@@ -110,6 +107,10 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
// 系统运行统计
"average_duration_ms": stats.AverageDurationMs,
"uptime": uptime,
// 性能指标
"rpm": stats.Rpm,
"tpm": stats.Tpm,
})
}
@@ -145,7 +146,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend, err := h.usageRepo.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -178,7 +179,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
@@ -203,7 +204,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
limit = 5
}
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
@@ -229,7 +230,7 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
limit = 12
}
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
@@ -258,11 +259,11 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
}
if len(req.UserIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
@@ -286,11 +287,11 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -3,9 +3,9 @@ package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -0,0 +1,228 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
type OpenAIOAuthHandler struct {
openaiOAuthService *service.OpenAIOAuthService
adminService service.AdminService
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
openaiOAuthService: openaiOAuthService,
adminService: adminService,
}
}
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
type OpenAIGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
RedirectURI string `json:"redirect_uri"`
}
// GenerateAuthURL generates OpenAI OAuth authorization URL
// POST /api/v1/admin/openai/generate-auth-url
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req OpenAIGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body
req = OpenAIGenerateAuthURLRequest{}
}
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
return
}
response.Success(c, result)
}
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
// ExchangeCode exchanges OpenAI authorization code for tokens
// POST /api/v1/admin/openai/exchange-code
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
var req OpenAIExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
var proxyURL string
if req.ProxyID != nil {
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
if err != nil {
response.BadRequest(c, "Failed to refresh token: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI account
// POST /api/v1/admin/openai/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// Ensure account is OpenAI platform
if !account.IsOpenAI() {
response.BadRequest(c, "Account is not an OpenAI account")
return
}
// Only refresh OAuth-based accounts
if !account.IsOAuth() {
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
return
}
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
// Build new credentials from token info
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
return
}
response.Success(c, updatedAccount)
}
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Exchange code for tokens
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
name = "OpenAI OAuth Account"
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
Platform: "openai",
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to create account: "+err.Error())
return
}
response.Success(c, account)
}

View File

@@ -4,8 +4,8 @@ import (
"strconv"
"strings"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -236,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
response.Paginated(c, accounts, total, page, pageSize)
}
// BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct {
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`

View File

@@ -6,8 +6,8 @@ import (
"fmt"
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
func (h *RedeemHandler) GetStats(c *gin.Context) {
// Return mock data for now
response.Success(c, gin.H{
"total_codes": 0,
"active_codes": 0,
"used_codes": 0,
"expired_codes": 0,
"total_codes": 0,
"active_codes": 0,
"used_codes": 0,
"expired_codes": 0,
"total_value_distributed": 0.0,
"by_type": gin.H{
"balance": 0,
@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf)
// Write header
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
// Write data rows
for _, code := range codes {
@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
}
writer.Write([]string{
if err := writer.Write([]string{
fmt.Sprintf("%d", code.ID),
code.Code,
code.Type,
@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
usedBy,
usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
})
}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
}
writer.Flush()
if err := writer.Error(); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
c.Header("Content-Type", "text/csv")
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")

View File

@@ -1,9 +1,9 @@
package admin
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -60,6 +60,7 @@ type UpdateSettingsRequest struct {
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -104,6 +105,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
SiteSubtitle: req.SiteSubtitle,
ApiBaseUrl: req.ApiBaseUrl,
ContactInfo: req.ContactInfo,
DocUrl: req.DocUrl,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
}

View File

@@ -3,10 +3,10 @@ package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -4,9 +4,9 @@ import (
"net/http"
"time"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/sysutil"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -4,35 +4,32 @@ import (
"strconv"
"time"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
adminService service.AdminService
usageService *service.UsageService
apiKeyService *service.ApiKeyService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageRepo *repository.UsageLogRepository,
apiKeyRepo *repository.ApiKeyRepository,
usageService *service.UsageService,
apiKeyService *service.ApiKeyService,
adminService service.AdminService,
) *UsageHandler {
return &UsageHandler{
usageRepo: usageRepo,
apiKeyRepo: apiKeyRepo,
usageService: usageService,
adminService: adminService,
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
}
}
@@ -84,14 +81,14 @@ func (h *UsageHandler) List(c *gin.Context) {
}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{
filters := usagestats.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,
StartTime: startTime,
EndTime: endTime,
}
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
return
@@ -179,7 +176,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
// Get global stats
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
@@ -193,7 +190,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
func (h *UsageHandler) SearchUsers(c *gin.Context) {
keyword := c.Query("q")
if keyword == "" {
response.Success(c, []interface{}{})
response.Success(c, []any{})
return
}
@@ -237,7 +234,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userID = id
}
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error())
return

View File

@@ -3,8 +3,8 @@ package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -25,6 +25,9 @@ func NewUserHandler(adminService service.AdminService) *UserHandler {
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
@@ -35,6 +38,9 @@ type CreateUserRequest struct {
type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Username *string `json:"username"`
Wechat *string `json:"wechat"`
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
@@ -43,8 +49,9 @@ type UpdateUserRequest struct {
// UpdateBalanceRequest represents balance update request
type UpdateBalanceRequest struct {
Balance float64 `json:"balance" binding:"required"`
Balance float64 `json:"balance" binding:"required,gt=0"`
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
Notes string `json:"notes"`
}
// List handles listing all users with pagination
@@ -94,6 +101,9 @@ func (h *UserHandler) Create(c *gin.Context) {
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
@@ -125,6 +135,9 @@ func (h *UserHandler) Update(c *gin.Context) {
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Wechat: req.Wechat,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
@@ -171,7 +184,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error())
return

View File

@@ -3,10 +3,10 @@ package handler
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -1,9 +1,9 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -10,27 +10,21 @@ import (
"strings"
"time"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/pkg/claude"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// Maximum wait time for concurrency slot
maxConcurrencyWait = 60 * time.Second
// Ping interval during wait
pingInterval = 5 * time.Second
)
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
userService *service.UserService
concurrencyService *service.ConcurrencyService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
}
// NewGatewayHandler creates a new GatewayHandler
@@ -38,8 +32,8 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
return &GatewayHandler{
gatewayService: gatewayService,
userService: userService,
concurrencyService: concurrencyService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
}
}
@@ -89,7 +83,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
@@ -98,10 +92,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 确保在函数退出时减少wait计数
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -139,7 +133,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -173,133 +167,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}()
}
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
}
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
}
// concurrencyError represents a concurrency limit error with context
type concurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *concurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
// Note: For streaming requests, we send ping to keep the connection alive.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// For streaming requests, set up SSE headers for ping
var flusher http.Flusher
if isStream {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &concurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingTicker.C:
// Send ping for streaming requests to keep connection alive
if isStream && flusher != nil {
// Set headers on first ping (lazy initialization)
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
flusher.Flush()
}
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}
// Models handles listing available models
// GET /v1/models
// Returns different model lists based on the API key's group platform
func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware.GetApiKeyFromContext(c)
// Return OpenAI models for OpenAI platform groups
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
})
return
}
// Default: Claude models
c.JSON(http.StatusOK, gin.H{
"data": claude.DefaultModels,
"object": "list",
"data": claude.DefaultModels,
})
}
@@ -414,7 +300,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
if ok {
// Send error event in SSE format
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
fmt.Fprint(c.Writer, errorEvent)
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
@@ -574,11 +462,11 @@ func sendMockWarmupStream(c *gin.Context, model string) {
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"id": "msg_mock_warmup",
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,

View File

@@ -0,0 +1,180 @@
package handler
import (
"context"
"fmt"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
maxConcurrencyWait = 30 * time.Second
// pingInterval is the interval for sending ping events during slot wait
pingInterval = 15 * time.Second
)
// SSEPingFormat defines the format of SSE ping events for different platforms
type SSEPingFormat string
const (
// SSEPingFormatClaude is the Claude/Anthropic SSE ping format
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
)
// ConcurrencyError represents a concurrency limit error with context
type ConcurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *ConcurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
}
}
// IncrementWaitCount increments the wait count for a user
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
}
// DecrementWaitCount decrements the wait count for a user
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
}
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// Determine if ping is needed (streaming + ping format defined)
needPing := isStream && h.pingFormat != ""
var flusher http.Flusher
if needPing {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &ConcurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingCh:
// Send ping to keep connection alive
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
return nil, err
}
flusher.Flush()
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}

View File

@@ -1,7 +1,7 @@
package handler
import (
"sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
)
// AdminHandlers contains all admin-related HTTP handlers
@@ -11,6 +11,7 @@ type AdminHandlers struct {
Group *admin.GroupHandler
Account *admin.AccountHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Setting *admin.SettingHandler
@@ -21,15 +22,16 @@ type AdminHandlers struct {
// Handlers contains all HTTP handlers
type Handlers struct {
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Admin *AdminHandlers
Gateway *GatewayHandler
Setting *SettingHandler
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
}
// BuildInfo contains build-time information

View File

@@ -0,0 +1,209 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
) *OpenAIGatewayHandler {
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
}
}
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// Read request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// Parse request body to map for potential modification
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
reqBody["instructions"] = openai.DefaultInstructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
// Track if we've started streaming (for error handling)
streamStarted := false
// Get subscription info (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
// Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
}
// Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c)
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}
// Async record usage
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in OpenAI SSE format
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
}
return
}
// Normal case: return JSON response with proper status code
h.errorResponse(c, status, errType, message)
}
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -1,9 +1,9 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -1,8 +1,8 @@
package handler
import (
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -1,9 +1,9 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -4,12 +4,11 @@ import (
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -17,15 +16,13 @@ import (
// UsageHandler handles usage-related requests
type UsageHandler struct {
usageService *service.UsageService
usageRepo *repository.UsageLogRepository
apiKeyService *service.ApiKeyService
}
// NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
return &UsageHandler{
usageService: usageService,
usageRepo: usageRepo,
apiKeyService: apiKeyService,
}
}
@@ -260,7 +257,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
return
}
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get dashboard statistics")
return
@@ -287,7 +284,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil {
response.InternalError(c, "Failed to get usage trend")
return
@@ -318,7 +315,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get model statistics")
return
@@ -358,7 +355,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
@@ -383,11 +380,11 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
}
if len(validApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil {
response.InternalError(c, "Failed to get API key usage stats")
return

View File

@@ -1,9 +1,9 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -26,6 +26,12 @@ type ChangePasswordRequest struct {
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
Wechat *string `json:"wechat"`
}
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
@@ -47,6 +53,9 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, userData)
}
@@ -83,3 +92,40 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
response.Success(c, gin.H{"message": "Password changed successfully"})
}
// UpdateProfile handles updating user profile
// PUT /api/v1/users/me
func (h *UserHandler) UpdateProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req UpdateProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.UpdateProfileRequest{
Username: req.Username,
Wechat: req.Wechat,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to update profile: "+err.Error())
return
}
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, updatedUser)
}

View File

@@ -1,8 +1,8 @@
package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
)
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler,
@@ -27,6 +28,7 @@ func ProvideAdminHandlers(
Group: groupHandler,
Account: accountHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Setting: settingHandler,
@@ -56,18 +58,20 @@ func ProvideHandlers(
subscriptionHandler *SubscriptionHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
User: userHandler,
APIKey: apiKeyHandler,
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
Setting: settingHandler,
Auth: authHandler,
User: userHandler,
APIKey: apiKeyHandler,
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler,
}
}
@@ -81,6 +85,7 @@ var ProviderSet = wire.NewSet(
NewRedeemHandler,
NewSubscriptionHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
ProvideSettingHandler,
// Admin handlers
@@ -89,6 +94,7 @@ var ProviderSet = wire.NewSet(
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewSettingHandler,

View File

@@ -1,9 +1,9 @@
package infrastructure
import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"gorm.io/driver/postgres"
"gorm.io/gorm"

View File

@@ -1,7 +1,7 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)

View File

@@ -1,7 +1,7 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"

View File

@@ -3,9 +3,9 @@ package middleware
import (
"context"
"crypto/subtle"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -1,7 +1,7 @@
package middleware
import (
"sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
)

View File

@@ -3,9 +3,9 @@ package middleware
import (
"context"
"errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"log"
"strings"
"sub2api/internal/model"
"github.com/gin-gonic/gin"
"gorm.io/gorm"

View File

@@ -2,9 +2,9 @@ package middleware
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)

View File

@@ -10,7 +10,7 @@ import (
)
// JSONB 用于存储JSONB数据
type JSONB map[string]interface{}
type JSONB map[string]any
func (j JSONB) Value() (driver.Value, error) {
if j == nil {
@@ -19,7 +19,7 @@ func (j JSONB) Value() (driver.Value, error) {
return json.Marshal(j)
}
func (j *JSONB) Scan(value interface{}) error {
func (j *JSONB) Scan(value any) error {
if value == nil {
*j = nil
return nil
@@ -40,8 +40,8 @@ type Account struct {
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
@@ -68,7 +68,8 @@ type Account struct {
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
Groups []*Group `gorm:"-" json:"groups,omitempty"`
}
func (Account) TableName() string {
@@ -145,7 +146,7 @@ func (a *Account) GetModelMapping() map[string]string {
return nil
}
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]interface{}); ok {
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
@@ -163,7 +164,7 @@ func (a *Account) GetModelMapping() map[string]string {
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
if len(mapping) == 0 {
return true // 没有映射配置,支持所有模型
}
_, exists := mapping[requestedModel]
@@ -174,7 +175,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
if len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
@@ -231,7 +232,7 @@ func (a *Account) GetCustomErrorCodes() []int {
return nil
}
// 处理 []interface{} 类型JSON反序列化后的格式
if arr, ok := raw.([]interface{}); ok {
if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
// JSON 数字默认解析为 float64
@@ -277,3 +278,138 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
}
return false
}
// =============== OpenAI 相关方法 ===============
// IsOpenAI 检查是否为 OpenAI 平台账号
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}
// IsAnthropic 检查是否为 Anthropic 平台账号
func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic
}
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号Response 账号)
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
// 对于 API Key 类型账号,从 credentials 中获取 base_url
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
}
}
return "https://api.openai.com" // OpenAI 默认 API URL
}
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("access_token")
}
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("refresh_token")
}
// GetOpenAIIDToken 获取 OpenAI ID TokenJWT包含用户信息
func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("id_token")
}
// GetOpenAIApiKey 获取 OpenAI API Key用于 Response 账号)
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")
}
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
// 返回空字符串表示透传原始 User-Agent
func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("user_agent")
}
// GetChatGPTAccountID 获取 ChatGPT 账号 ID从 ID Token 解析)
func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_account_id")
}
// GetChatGPTUserID 获取 ChatGPT 用户 ID从 ID Token 解析)
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_user_id")
}
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("organization_id")
}
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
// 尝试解析时间
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
// 尝试解析为 Unix 时间戳
if v, ok := a.Credentials["expires_at"].(float64); ok {
t = time.Unix(int64(v), 0)
return &t
}
return nil
}
return &t
}
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false // 没有过期时间信息,假设未过期
}
// 提前 60 秒认为过期,便于刷新
return time.Now().Add(60 * time.Second).After(*expiresAt)
}

View File

@@ -13,13 +13,13 @@ const (
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription

View File

@@ -9,15 +9,16 @@ import (
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
@@ -40,8 +41,10 @@ func (r *RedeemCode) CanUse() bool {
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() string {
func GenerateRedeemCode() (string, error) {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}

View File

@@ -19,17 +19,17 @@ func (Setting) TableName() string {
// 设置Key常量
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
@@ -42,6 +42,7 @@ const (
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
@@ -80,6 +81,7 @@ type SystemSettings struct {
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
@@ -97,5 +99,6 @@ type PublicSettings struct {
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -37,7 +37,7 @@ type UsageLog struct {
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率

View File

@@ -9,8 +9,11 @@ import (
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
Username string `gorm:"size:100;default:''" json:"username"`
Wechat string `gorm:"size:100;default:''" json:"wechat"`
Notes string `gorm:"type:text;default:''" json:"notes"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
@@ -22,7 +25,8 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
}
func (User) TableName() string {

View File

@@ -43,18 +43,25 @@ type OAuthSession struct {
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
@@ -87,14 +94,20 @@ func (s *SessionStore) Delete(sessionID string) {
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
s.mu.Unlock()
}
}

View File

@@ -0,0 +1,42 @@
package openai
import _ "embed"
// Model represents an OpenAI model
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
}
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
}
// DefaultModelIDs returns the default model ID list
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel default model for testing OpenAI accounts
const DefaultTestModel = "gpt-5.1-codex"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
//
//go:embed instructions.txt
var DefaultInstructions string

View File

@@ -0,0 +1,118 @@
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
## General
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
## Editing constraints
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
- You may be in a dirty git worktree.
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
* If the changes are in unrelated files, just ignore them and don't revert them.
- Do not amend a commit unless explicitly requested to do so.
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
## Plan tool
When using the planning tool:
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
- Do not make single-step plans.
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
## Codex CLI harness, sandboxing, and approvals
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
- **read-only**: The sandbox only permits reading files.
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
- **restricted**: Requires approval
- **enabled**: No approval needed
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
- (for all of these, you should weigh alternative paths that do not require approval)
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
When requesting approval to execute a command that will require escalated privileges:
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
## Special user requests
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
## Frontend tasks
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
Aim for interfaces that feel intentional, bold, and a bit surprising.
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
- Ensure the page loads properly on both desktop and mobile
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
## Presenting your work and final message
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
- Default: be very concise; friendly coding teammate tone.
- Ask only when needed; suggest ideas; mirror the user's style.
- For substantial work, summarize clearly; follow finalanswer formatting.
- Skip heavy formatting for simple confirmations.
- Don't dump large files you've written; reference paths only.
- No \"save/copy this file\" - User is on the same machine.
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
- For code changes:
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
### Final answer structure and style guidelines
- Plain text; CLI handles styling. Use structure only when it helps scanability.
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
- Bullets: use - ; merge related points; keep to one line when possible; 46 per list ordered by importance; keep phrasing consistent.
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
- Tone: collaborative, concise, factual; present tense, active voice; selfcontained; no \"above/below\"; parallel wording.
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
- File References: When referencing files in your response follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Optionally include line/column (1based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5

View File

@@ -0,0 +1,366 @@
package openai
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
TokenURL = "https://auth.openai.com/oauth/token"
// Default redirect URI (can be customized)
DefaultRedirectURI = "http://localhost:1455/auth/callback"
// Scopes
DefaultScopes = "openid profile email offline_access"
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
RefreshScopes = "openid profile email"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
// OpenAI uses hex encoding instead of base64url
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(64)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
// Uses base64url encoding as per RFC 7636
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
params.Set("codex_cli_simplified_flow", "true")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
}
// TokenResponse represents the token response from OpenAI OAuth
type TokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// IDTokenClaims represents the claims from OpenAI ID Token
type IDTokenClaims struct {
// Standard claims
Sub string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Iss string `json:"iss"`
Aud []string `json:"aud"` // OpenAI returns aud as an array
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
// OpenAI specific claims (nested under https://api.openai.com/auth)
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
}
// OpenAIAuthClaims represents the OpenAI specific auth claims
type OpenAIAuthClaims struct {
ChatGPTAccountID string `json:"chatgpt_account_id"`
ChatGPTUserID string `json:"chatgpt_user_id"`
UserID string `json:"user_id"`
Organizations []OrganizationClaim `json:"organizations"`
}
// OrganizationClaim represents an organization in the ID Token
type OrganizationClaim struct {
ID string `json:"id"`
Role string `json:"role"`
Title string `json:"title"`
IsDefault bool `json:"is_default"`
}
// BuildTokenRequest creates a token exchange request for OpenAI
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: redirectURI,
CodeVerifier: codeVerifier,
}
}
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
Scope: RefreshScopes,
}
}
// ToFormData converts TokenRequest to URL-encoded form data
func (r *TokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("code", r.Code)
params.Set("redirect_uri", r.RedirectURI)
params.Set("code_verifier", r.CodeVerifier)
return params.Encode()
}
// ToFormData converts RefreshTokenRequest to URL-encoded form data
func (r *RefreshTokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("refresh_token", r.RefreshToken)
params.Set("scope", r.Scope)
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims
// Note: This does NOT verify the signature - it only decodes the payload
// For production, you should verify the token signature using OpenAI's public keys
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
// Decode payload (second part)
payload := parts[1]
// Add padding if necessary
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
// Try standard encoding
decoded, err = base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
}
}
var claims IDTokenClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
return &claims, nil
}
// ExtractUserInfo extracts user information from ID Token claims
type UserInfo struct {
Email string
ChatGPTAccountID string
ChatGPTUserID string
UserID string
OrganizationID string
Organizations []OrganizationClaim
}
// GetUserInfo extracts user info from ID Token claims
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
info := &UserInfo{
Email: c.Email,
}
if c.OpenAIAuth != nil {
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
info.UserID = c.OpenAIAuth.UserID
info.Organizations = c.OpenAIAuth.Organizations
// Get default organization ID
for _, org := range c.OpenAIAuth.Organizations {
if org.IsDefault {
info.OrganizationID = org.ID
break
}
}
// If no default, use first org
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
}
}
return info
}

View File

@@ -0,0 +1,18 @@
package openai
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{
"codex_vscode/",
"codex_cli_rs/",
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool {
for _, prefix := range CodexCLIUserAgentPrefixes {
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
return true
}
}
return false
}

View File

@@ -9,22 +9,22 @@ import (
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
type PaginatedData struct {
Items interface{} `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
Items any `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
}
// Success 返回成功响应
func Success(c *gin.Context, data interface{}) {
func Success(c *gin.Context, data any) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
@@ -33,7 +33,7 @@ func Success(c *gin.Context, data interface{}) {
}
// Created 返回创建成功响应
func Created(c *gin.Context, data interface{}) {
func Created(c *gin.Context, data any) {
c.JSON(http.StatusCreated, Response{
Code: 0,
Message: "success",
@@ -75,7 +75,7 @@ func InternalError(c *gin.Context, message string) {
}
// Paginated 返回分页数据
func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize int) {
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
@@ -99,7 +99,7 @@ type PaginationResult struct {
}
// PaginatedWithResult 使用PaginationResult返回分页数据
func PaginatedWithResult(c *gin.Context, items interface{}, pagination *PaginationResult) {
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
if pagination == nil {
Success(c, PaginatedData{
Items: items,

View File

@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first
Init("UTC")
if err := Init("UTC"); err != nil {
t.Fatalf("Init failed with UTC: %v", err)
}
utcNow := time.Now()
// Switch to Shanghai (UTC+8)
Init("Asia/Shanghai")
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
}
func TestToday(t *testing.T) {
Init("Asia/Shanghai")
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
today := Today()
now := Now()
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
}
func TestStartOfDay(t *testing.T) {
Init("Asia/Shanghai")
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
// Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code
Init("Asia/Shanghai")
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
now := Now()

View File

@@ -0,0 +1,209 @@
package usagestats
import "time"
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
StartTime *time.Time
EndTime *time.Time
}
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"`
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
HighestCostDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
} `json:"highest_request_day"`
}
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
}

View File

@@ -2,11 +2,14 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type AccountRepository struct {
@@ -23,14 +26,34 @@ func (r *AccountRepository) Create(ctx context.Context, account *model.Account)
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
if err != nil {
return nil, err
}
// 填充 GroupIDs 虚拟字段
// 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
return &account, nil
}
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
if crsAccountID == "" {
return nil, nil
}
var account model.Account
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &account, nil
}
@@ -78,15 +101,19 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err
}
if err := db.Preload("Proxy").Preload("AccountGroups").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
if err := db.Preload("Proxy").Preload("AccountGroups.Group").Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&accounts).Error; err != nil {
return nil, nil, err
}
// 填充每个 Account 的 GroupIDs 虚拟字段
// 填充每个 Account 的虚拟字段(GroupIDs 和 Groups
for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
if ag.Group != nil {
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
}
}
}
@@ -131,7 +158,7 @@ func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"status": model.StatusError,
"error_message": errorMsg,
}).Error
@@ -222,11 +249,43 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
return accounts, err
}
// ListSchedulableByPlatform 按平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
}
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
}
// SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
@@ -241,7 +300,7 @@ func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until t
// ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"rate_limited_at": nil,
"rate_limit_reset_at": nil,
"overload_until": nil,
@@ -250,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]interface{}{
updates := map[string]any{
"session_window_status": status,
}
if start != nil {
@@ -267,3 +326,75 @@ func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
}
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
}
// Get current account to preserve existing Extra data
var account model.Account
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err
}
// Initialize Extra if nil
if account.Extra == nil {
account.Extra = make(model.JSONB)
}
// Merge updates into existing Extra
for k, v := range updates {
account.Extra[k] = v
}
// Save updated Extra
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("extra", account.Extra).Error
}
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
updateMap := map[string]any{}
if updates.Name != nil {
updateMap["name"] = *updates.Name
}
if updates.ProxyID != nil {
updateMap["proxy_id"] = updates.ProxyID
}
if updates.Concurrency != nil {
updateMap["concurrency"] = *updates.Concurrency
}
if updates.Priority != nil {
updateMap["priority"] = *updates.Priority
}
if updates.Status != nil {
updateMap["status"] = *updates.Status
}
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
}
if len(updateMap) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).
Model(&model.Account{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
return result.RowsAffected, result.Error
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)

View File

@@ -2,8 +2,8 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)

View File

@@ -8,7 +8,7 @@ import (
"strconv"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
@@ -143,7 +143,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
fields := map[string]any{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,

View File

@@ -7,10 +7,11 @@ import (
"log"
"net/http"
"net/url"
"strings"
"time"
"sub2api/internal/pkg/oauth"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
@@ -64,7 +65,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
reqBody := map[string]interface{}{
reqBody := map[string]any{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID,
@@ -139,23 +140,15 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code
codeState := ""
if len(code) > 0 {
parts := make([]string, 0, 2)
for i, part := range []rune(code) {
if part == '#' {
authCode = code[:i]
codeState = code[i+1:]
break
}
}
if len(parts) == 0 {
authCode = code
}
if idx := strings.Index(code, "#"); idx != -1 {
authCode = code[:idx]
codeState = code[idx+1:]
}
reqBody := map[string]interface{}{
reqBody := map[string]any{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,

View File

@@ -9,7 +9,7 @@ import (
"net/url"
"time"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type claudeUsageService struct{}
@@ -19,7 +19,11 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("failed to get default transport")
}
transport = transport.Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
@@ -43,7 +47,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)

View File

@@ -5,60 +5,95 @@ import (
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
accountConcurrencyKeyPrefix = "concurrency:account:"
userConcurrencyKeyPrefix = "concurrency:user:"
waitQueueKeyPrefix = "concurrency:wait:"
concurrencyTTL = 5 * time.Minute
// Key prefixes for independent slot keys
// Format: concurrency:account:{accountID}:{requestID}
accountSlotKeyPrefix = "concurrency:account:"
// Format: concurrency:user:{userID}:{requestID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Slot TTL - each slot expires independently
slotTTL = 5 * time.Minute
)
var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
local pattern = KEYS[1]
local slotKey = KEYS[2]
local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
-- Count existing slots using SCAN
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- Check if we can acquire a slot
if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl)
return 1
end
return 0
`)
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
// getCountScript counts slots using SCAN
// KEYS[1] = pattern for SCAN
getCountScript = redis.NewScript(`
local pattern = KEYS[1]
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
return count
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
if current >= tonumber(ARGV[1]) then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
@@ -76,49 +111,86 @@ func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
}
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
// Helper functions for key generation
func accountSlotKey(accountID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
}
func accountSlotPattern(accountID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID)
slotKey := accountSlotKey(accountID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
return err
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
slotKey := accountSlotKey(accountID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
return c.rdb.Get(ctx, key).Int()
pattern := accountSlotPattern(accountID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
}
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID)
slotKey := userSlotKey(userID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
return err
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
slotKey := userSlotKey(userID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
pattern := userSlotPattern(userID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
}
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
if err != nil {
return false, err
}
@@ -126,7 +198,7 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -5,7 +5,7 @@ import (
"encoding/json"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)

View File

@@ -4,7 +4,7 @@ import (
"context"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)

View File

@@ -9,7 +9,7 @@ import (
"os"
"time"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type githubReleaseClient struct {
@@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin
if err != nil {
return nil, err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
@@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
if err != nil {
return err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
@@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
if err != nil {
return err
}
defer out.Close()
defer func() { _ = out.Close() }()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxSize+1)
@@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
// Check if we hit the limit (downloaded more than maxSize)
if written > maxSize {
os.Remove(dest) // Clean up partial file
_ = os.Remove(dest) // Clean up partial file (best-effort)
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
}
@@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string)
if err != nil {
return nil, err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)

View File

@@ -2,8 +2,8 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)

View File

@@ -5,16 +5,19 @@ import (
"net/url"
"time"
"sub2api/internal/config"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
type claudeUpstreamService struct {
// httpUpstreamService is a generic HTTP upstream service that can be used for
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
type httpUpstreamService struct {
defaultClient *http.Client
cfg *config.Config
}
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
// NewHTTPUpstream creates a new generic HTTP upstream service
func NewHTTPUpstream(cfg *config.Config) ports.HTTPUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
@@ -27,13 +30,13 @@ func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &claudeUpstreamService{
return &httpUpstreamService{
defaultClient: &http.Client{Transport: transport},
cfg: cfg,
}
}
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
if proxyURL == "" {
return s.defaultClient.Do(req)
}
@@ -41,7 +44,7 @@ func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Re
return client.Do(req)
}
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return s.defaultClient

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)

View File

@@ -0,0 +1,92 @@
package repository
import (
"context"
"fmt"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/imroc/req/v3"
)
type openaiOAuthService struct{}
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
return &openaiOAuthService{}
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", openai.ClientID)
formData.Set("code", code)
formData.Set("redirect_uri", redirectURI)
formData.Set("code_verifier", codeVerifier)
var tokenResp openai.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID)
formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
client := req.C().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type pricingRemoteClient struct {
@@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string)
if err != nil {
return nil, err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
@@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st
if err != nil {
return "", err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)

View File

@@ -11,7 +11,7 @@ import (
"net/url"
"time"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"golang.org/x/net/proxy"
)
@@ -43,7 +43,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
if err != nil {
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
latencyMs := time.Since(startTime).Milliseconds()

View File

@@ -2,8 +2,8 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)

View File

@@ -2,8 +2,8 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm"
@@ -99,7 +99,7 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused).
Updates(map[string]interface{}{
Updates(map[string]any{
"status": model.StatusUsed,
"used_by": userID,
"used_at": now,

View File

@@ -2,7 +2,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/model"
"time"
"gorm.io/gorm"

View File

@@ -9,7 +9,7 @@ import (
"strings"
"time"
"sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
@@ -44,7 +44,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
var result service.TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {

View File

@@ -4,7 +4,7 @@ import (
"context"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)

View File

@@ -2,10 +2,10 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone"
"sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"time"
"gorm.io/gorm"
@@ -19,6 +19,30 @@ func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository {
return &UsageLogRepository{db: db}
}
// getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
var perfStats struct {
RequestCount int64 `gorm:"column:request_count"`
TokenCount int64 `gorm:"column:token_count"`
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
`).
Where("created_at >= ?", fiveMinutesAgo)
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
db.Scan(&perfStats)
// 返回5分钟平均值
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
}
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
@@ -113,46 +137,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
}
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
}
type DashboardStats = usagestats.DashboardStats
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats
@@ -269,6 +254,9 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
stats.TodayCost = todayStats.TodayCost
stats.TodayActualCost = todayStats.TodayActualCost
// 性能指标RPM 和 TPM最近1分钟全局
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, 0)
return &stats, nil
}
@@ -398,47 +386,16 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
type TrendDataPoint = usagestats.TrendDataPoint
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
@@ -531,34 +488,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
}
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
}
type UserDashboardStats = usagestats.UserDashboardStats
// GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
@@ -641,6 +571,9 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
stats.TodayCost = todayStats.TodayCost
stats.TodayActualCost = todayStats.TodayActualCost
// 性能指标RPM 和 TPM最近1分钟仅统计该用户的请求
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, userID)
return &stats, nil
}
@@ -705,12 +638,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
}
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
StartTime *time.Time
EndTime *time.Time
}
type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
@@ -758,23 +686,10 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params paginat
}
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
@@ -834,11 +749,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
@@ -937,7 +848,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]ModelStat, error) {
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
@@ -958,6 +869,9 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
if apiKeyID > 0 {
db = db.Where("api_key_id = ?", apiKeyID)
}
if accountID > 0 {
db = db.Where("account_id = ?", accountID)
}
err := db.Group("model").Order("total_tokens DESC").Scan(&results).Error
if err != nil {
@@ -1007,3 +921,169 @@ func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
AverageDurationMs: stats.AverageDurationMs,
}, nil
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory = usagestats.AccountUsageHistory
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary = usagestats.AccountUsageSummary
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 {
daysCount = 30
}
// Get daily history
var historyResults []struct {
Date string `gorm:"column:date"`
Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"`
Cost float64 `gorm:"column:cost"`
ActualCost float64 `gorm:"column:actual_cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Group("date").
Order("date ASC").
Scan(&historyResults).Error
if err != nil {
return nil, err
}
// Build history with labels
history := make([]AccountUsageHistory, 0, len(historyResults))
for _, h := range historyResults {
// Parse date to get label (MM/DD)
t, _ := time.Parse("2006-01-02", h.Date)
label := t.Format("01/02")
history = append(history, AccountUsageHistory{
Date: h.Date,
Label: label,
Requests: h.Requests,
Tokens: h.Tokens,
Cost: h.Cost,
ActualCost: h.ActualCost,
})
}
// Calculate summary
var totalActualCost, totalStandardCost float64
var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history {
h := &history[i]
totalActualCost += h.ActualCost
totalStandardCost += h.Cost
totalRequests += h.Requests
totalTokens += h.Tokens
if highestCostDay == nil || h.ActualCost > highestCostDay.ActualCost {
highestCostDay = h
}
if highestRequestDay == nil || h.Requests > highestRequestDay.Requests {
highestRequestDay = h
}
}
actualDaysUsed := len(history)
if actualDaysUsed == 0 {
actualDaysUsed = 1
}
// Get average duration
var avgDuration struct {
AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Scan(&avgDuration)
summary := AccountUsageSummary{
Days: daysCount,
ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost,
TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration.AvgDurationMs,
}
// Set today's stats
todayStr := timezone.Now().Format("2006-01-02")
for i := range history {
if history[i].Date == todayStr {
summary.Today = &struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}{
Date: history[i].Date,
Cost: history[i].ActualCost,
Requests: history[i].Requests,
Tokens: history[i].Tokens,
}
break
}
}
// Set highest cost day
if highestCostDay != nil {
summary.HighestCostDay = &struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
}{
Date: highestCostDay.Date,
Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost,
Requests: highestCostDay.Requests,
}
}
// Set highest request day
if highestRequestDay != nil {
summary.HighestRequestDay = &struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
}{
Date: highestRequestDay.Date,
Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost,
}
}
// Get model statistics using the unified method
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
if err != nil {
models = []ModelStat{}
}
return &AccountUsageStatsResponse{
History: history,
Summary: summary,
Models: models,
}, nil
}

View File

@@ -2,8 +2,8 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -66,17 +66,47 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
}
if search != "" {
searchPattern := "%" + search + "%"
db = db.Where("email ILIKE ?", searchPattern)
db = db.Where(
"email ILIKE ? OR username ILIKE ? OR wechat ILIKE ?",
searchPattern, searchPattern, searchPattern,
)
}
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
}
// Query users with pagination (reuse the same db with filters applied)
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
return nil, nil, err
}
// Batch load subscriptions for all users (avoid N+1)
if len(users) > 0 {
userIDs := make([]int64, len(users))
userMap := make(map[int64]*model.User, len(users))
for i := range users {
userIDs[i] = users[i].ID
userMap[users[i].ID] = &users[i]
}
// Query active subscriptions with groups in one query
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive).
Find(&subscriptions).Error; err != nil {
return nil, nil, err
}
// Associate subscriptions with users
for i := range subscriptions {
if user, ok := userMap[subscriptions[i].UserID]; ok {
user.Subscriptions = append(user.Subscriptions, subscriptions[i])
}
}
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++

View File

@@ -4,8 +4,8 @@ import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -185,7 +185,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
"weekly_usage_usd": gorm.Expr("weekly_usage_usd + ?", costUSD),
"monthly_usage_usd": gorm.Expr("monthly_usage_usd + ?", costUSD),
@@ -197,7 +197,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
@@ -208,7 +208,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
@@ -219,7 +219,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
@@ -230,7 +230,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"daily_window_start": activateTime,
"weekly_window_start": activateTime,
"monthly_window_start": activateTime,
@@ -242,7 +242,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
@@ -252,7 +252,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
@@ -262,7 +262,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]interface{}{
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
@@ -281,7 +281,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Updates(map[string]interface{}{
Updates(map[string]any{
"status": model.SubscriptionStatusExpired,
"updated_at": time.Now(),
})

View File

@@ -1,7 +1,7 @@
package repository
import (
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/google/wire"
)
@@ -36,7 +36,8 @@ var ProviderSet = wire.NewSet(
NewProxyExitInfoProber,
NewClaudeUsageFetcher,
NewClaudeOAuthClient,
NewClaudeUpstream,
NewHTTPUpstream,
NewOpenAIOAuthClient,
// Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)),

View File

@@ -1,11 +1,11 @@
package server
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/repository"
"sub2api/internal/service"
"time"
"github.com/gin-gonic/gin"

View File

@@ -1,13 +1,13 @@
package server
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/web"
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/web"
"github.com/gin-gonic/gin"
)
@@ -82,6 +82,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
}
// API Key管理
@@ -179,6 +180,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
@@ -191,8 +193,10 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// OAuth routes
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
@@ -201,6 +205,16 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// OpenAI OAuth routes
openai := admin.Group("/openai")
{
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
@@ -289,5 +303,10 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses)
}

View File

@@ -4,9 +4,9 @@ import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -17,27 +17,27 @@ var (
// CreateAccountRequest 创建账号请求
type CreateAccountRequest struct {
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]interface{} `json:"credentials"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
// UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct {
Name *string `json:"name"`
Credentials *map[string]interface{} `json:"credentials"`
Extra *map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
Name *string `json:"name"`
Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
}
// AccountService 账号管理服务

View File

@@ -14,15 +14,19 @@ import (
"strings"
"time"
"sub2api/internal/pkg/claude"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testOpenAIAPIURL = "https://api.openai.com/v1/responses"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
)
// TestEvent represents a SSE event for account testing
@@ -36,37 +40,46 @@ type TestEvent struct {
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo ports.AccountRepository
oauthService *OAuthService
claudeUpstream ClaudeUpstream
accountRepo ports.AccountRepository
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
httpUpstream ports.HTTPUpstream
}
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream ports.HTTPUpstream) *AccountTestService {
return &AccountTestService{
accountRepo: accountRepo,
oauthService: oauthService,
claudeUpstream: claudeUpstream,
accountRepo: accountRepo,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
httpUpstream: httpUpstream,
}
}
// generateSessionString generates a Claude Code style session string
func generateSessionString() string {
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
rand.Read(bytes)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
hex64 := hex.EncodeToString(bytes)
sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
}
// createTestPayload creates a Claude Code style test request payload
func createTestPayload(modelID string) map[string]interface{} {
return map[string]interface{}{
func createTestPayload(modelID string) (map[string]any, error) {
sessionID, err := generateSessionString()
if err != nil {
return nil, err
}
return map[string]any{
"model": modelID,
"messages": []map[string]interface{}{
"messages": []map[string]any{
{
"role": "user",
"content": []map[string]interface{}{
"content": []map[string]any{
{
"type": "text",
"text": "hi",
@@ -77,7 +90,7 @@ func createTestPayload(modelID string) map[string]interface{} {
},
},
},
"system": []map[string]interface{}{
"system": []map[string]any{
{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
@@ -87,12 +100,12 @@ func createTestPayload(modelID string) map[string]interface{} {
},
},
"metadata": map[string]string{
"user_id": generateSessionString(),
"user_id": sessionID,
},
"max_tokens": 1024,
"temperature": 1,
"stream": true,
}
}, nil
}
// TestAccountConnection tests an account's connection by sending a test request
@@ -107,6 +120,18 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.sendErrorAndEnd(c, "Account not found")
}
// Route to platform-specific test method
if account.IsOpenAI() {
return s.testOpenAIAccountConnection(c, account, modelID)
}
return s.testClaudeAccountConnection(c, account, modelID)
}
// testClaudeAccountConnection tests an Anthropic Claude account's connection
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
ctx := c.Request.Context()
// Determine the model to use
testModelID := modelID
if testModelID == "" {
@@ -116,7 +141,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if mapping != nil && len(mapping) > 0 {
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
@@ -178,7 +203,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
c.Writer.Flush()
// Create Claude Code style payload (same for all account types)
payload := createTestPayload(testModelID)
payload, err := createTestPayload(testModelID)
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create test payload")
}
payloadBytes, _ := json.Marshal(payload)
// Send test_start event
@@ -212,11 +240,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
proxyURL = account.Proxy.URL()
}
resp, err := s.claudeUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -224,11 +252,155 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
// Process SSE stream
return s.processStream(c, resp.Body)
return s.processClaudeStream(c, resp.Body)
}
// processStream processes the SSE stream from Claude API
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error {
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
ctx := c.Request.Context()
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
if testModelID == "" {
testModelID = openai.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
}
// Determine authentication method and API URL
var authToken string
var apiURL string
var isOAuth bool
var chatgptAccountID string
if account.IsOAuth() {
isOAuth = true
// OAuth - use Bearer token with ChatGPT internal API
authToken = account.GetOpenAIAccessToken()
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
// Check if token is expired and refresh if needed
if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
}
authToken = tokenInfo.AccessToken
}
// OAuth uses ChatGPT internal API
apiURL = chatgptCodexAPIURL
chatgptAccountID = account.GetChatGPTAccountID()
} else if account.Type == "apikey" {
// API Key - use Platform API
authToken = account.GetOpenAIApiKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
apiURL = strings.TrimSuffix(baseURL, "/") + "/v1/responses"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create OpenAI Responses API payload
payload := createOpenAITestPayload(testModelID, isOAuth)
payloadBytes, _ := json.Marshal(payload)
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
// Set common headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+authToken)
// Set OAuth-specific headers for ChatGPT internal API
if isOAuth {
req.Host = "chatgpt.com"
req.Header.Set("accept", "text/event-stream")
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
}
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Process SSE stream
return s.processOpenAIStream(c, resp.Body)
}
// createOpenAITestPayload creates a test payload for OpenAI Responses API
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
payload := map[string]any{
"model": modelID,
"input": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{
"type": "input_text",
"text": "hi",
},
},
},
},
"stream": true,
}
// OAuth accounts using ChatGPT internal API require store: false
if isOAuth {
payload["store"] = false
}
// All accounts require instructions for Responses API
payload["instructions"] = openai.DefaultInstructions
return payload
}
// processClaudeStream processes the SSE stream from Claude API
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
for {
@@ -252,7 +424,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
return nil
}
var data map[string]interface{}
var data map[string]any
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
continue
}
@@ -261,7 +433,7 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
switch eventType {
case "content_block_delta":
if delta, ok := data["delta"].(map[string]interface{}); ok {
if delta, ok := data["delta"].(map[string]any); ok {
if text, ok := delta["text"].(string); ok {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
@@ -271,7 +443,60 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
return nil
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]interface{}); ok {
if errData, ok := data["error"].(map[string]any); ok {
if msg, ok := errData["message"].(string); ok {
errorMsg = msg
}
}
return s.sendErrorAndEnd(c, errorMsg)
}
}
}
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
jsonStr := strings.TrimPrefix(line, "data: ")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
var data map[string]any
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
continue
}
eventType, _ := data["type"].(string)
switch eventType {
case "response.output_text.delta":
// OpenAI Responses API uses "delta" field for text content
if delta, ok := data["delta"].(string); ok && delta != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
}
case "response.completed":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]any); ok {
if msg, ok := errData["message"].(string); ok {
errorMsg = msg
}
@@ -284,7 +509,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
// sendEvent sends a SSE event to the client
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
log.Printf("failed to write SSE event: %v", err)
return
}
c.Writer.Flush()
}

View File

@@ -7,8 +7,9 @@ import (
"sync"
"time"
"sub2api/internal/model"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
// usageCache 用于缓存usage数据
@@ -70,16 +71,14 @@ type ClaudeUsageFetcher interface {
type AccountUsageService struct {
accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository
oauthService *OAuthService
usageFetcher ClaudeUsageFetcher
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
oauthService: oauthService,
usageFetcher: usageFetcher,
}
}
@@ -98,8 +97,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
if account.CanGetUsage() {
// 检查缓存
if cached, ok := usageCacheMap.Load(accountID); ok {
cache := cached.(*usageCache)
if time.Since(cache.timestamp) < cacheTTL {
cache, ok := cached.(*usageCache)
if !ok {
usageCacheMap.Delete(accountID)
} else if time.Since(cache.timestamp) < cacheTTL {
return cache.data, nil
}
}
@@ -176,6 +177,14 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}, nil
}
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get account usage stats failed: %w", err)
}
return stats, nil
}
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
accessToken := account.GetCredential("access_token")

View File

@@ -4,11 +4,12 @@ import (
"context"
"errors"
"fmt"
"log"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -21,9 +22,9 @@ type AdminService interface {
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
@@ -44,6 +45,7 @@ type AdminService interface {
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
@@ -70,6 +72,9 @@ type AdminService interface {
type CreateUserInput struct {
Email string
Password string
Username string
Wechat string
Notes string
Balance float64
Concurrency int
AllowedGroups []int64
@@ -78,6 +83,9 @@ type CreateUserInput struct {
type UpdateUserInput struct {
Email string
Password string
Username *string
Wechat *string
Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string
@@ -113,8 +121,8 @@ type CreateAccountInput struct {
Name string
Platform string
Type string
Credentials map[string]interface{}
Extra map[string]interface{}
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
@@ -124,8 +132,8 @@ type CreateAccountInput struct {
type UpdateAccountInput struct {
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]interface{}
Extra map[string]interface{}
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
@@ -133,6 +141,33 @@ type UpdateAccountInput struct {
GroupIDs *[]int64
}
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
Status string
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
}
// BulkUpdateAccountResult captures the result for a single account update.
type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []BulkUpdateAccountResult `json:"results"`
}
type CreateProxyInput struct {
Name string
Protocol string
@@ -192,8 +227,6 @@ type adminServiceImpl struct {
proxyRepo ports.ProxyRepository
apiKeyRepo ports.ApiKeyRepository
redeemCodeRepo ports.RedeemCodeRepository
usageLogRepo ports.UsageLogRepository
userSubRepo ports.UserSubscriptionRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
}
@@ -206,8 +239,6 @@ func NewAdminService(
proxyRepo ports.ProxyRepository,
apiKeyRepo ports.ApiKeyRepository,
redeemCodeRepo ports.RedeemCodeRepository,
usageLogRepo ports.UsageLogRepository,
userSubRepo ports.UserSubscriptionRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
) AdminService {
@@ -218,8 +249,6 @@ func NewAdminService(
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
usageLogRepo: usageLogRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
}
@@ -242,6 +271,9 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User,
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
user := &model.User{
Email: input.Email,
Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes,
Role: "user", // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
@@ -267,8 +299,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
return nil, errors.New("cannot disable admin user")
}
// Track balance and concurrency changes for logging
oldBalance := user.Balance
oldConcurrency := user.Concurrency
if input.Email != "" {
@@ -279,22 +309,25 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
return nil, err
}
}
// Role is not allowed to be changed via API to prevent privilege escalation
if input.Username != nil {
user.Username = *input.Username
}
if input.Wechat != nil {
user.Wechat = *input.Wechat
}
if input.Notes != nil {
user.Notes = *input.Notes
}
if input.Status != "" {
user.Status = input.Status
}
// 只在指针非 nil 时更新 Balance支持设置为 0
if input.Balance != nil {
user.Balance = *input.Balance
}
// 只在指针非 nil 时更新 Concurrency支持设置为任意值
if input.Concurrency != nil {
user.Concurrency = *input.Concurrency
}
// 只在指针非 nil 时更新 AllowedGroups
if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups
}
@@ -303,39 +336,15 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
return nil, err
}
// 余额变化时失效缓存
if input.Balance != nil && *input.Balance != oldBalance {
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
}()
}
}
// Create adjustment records for balance/concurrency changes
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(),
Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update
// The user update has already succeeded
}
}
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
code, err := model.GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(),
Code: code,
Type: model.AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff),
Status: model.StatusUsed,
@@ -344,8 +353,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update
// The user update has already succeeded
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
}
}
@@ -364,12 +372,14 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return s.userRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) {
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
oldBalance := user.Balance
switch operation {
case "set":
user.Balance = balance
@@ -379,19 +389,48 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
user.Balance -= balance
}
if user.Balance < 0 {
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
}
}()
}
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := model.GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
Code: code,
Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
UsedBy: &user.ID,
Notes: notes,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
log.Printf("failed to create balance adjustment redeem code: %v", err)
}
}
return user, nil
}
@@ -404,9 +443,9 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return keys, result.Total, nil
}
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) {
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
// Return mock data for now
return map[string]interface{}{
return map[string]any{
"period": period,
"total_requests": 0,
"total_cost": 0.0,
@@ -579,7 +618,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
for _, userID := range affectedUserIDs {
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
}
}
}()
}
@@ -646,10 +687,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Type != "" {
account.Type = input.Type
}
if input.Credentials != nil && len(input.Credentials) > 0 {
if len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials)
}
if input.Extra != nil && len(input.Extra) > 0 {
if len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra)
}
if input.ProxyID != nil {
@@ -681,6 +722,65 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return account, nil
}
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
result := &BulkUpdateAccountsResult{
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
}
if len(input.AccountIDs) == 0 {
return result, nil
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := ports.AccountBulkUpdate{
Credentials: input.Credentials,
Extra: input.Extra,
}
if input.Name != "" {
repoUpdates.Name = &input.Name
}
if input.ProxyID != nil {
repoUpdates.ProxyID = input.ProxyID
}
if input.Concurrency != nil {
repoUpdates.Concurrency = input.Concurrency
}
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
// Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
return nil, err
}
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
entry.Success = true
result.Success++
result.Results = append(result.Results, entry)
}
return result, nil
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
}
@@ -831,8 +931,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
codes := make([]model.RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ {
codeValue, err := model.GenerateRedeemCode()
if err != nil {
return nil, err
}
code := model.RedeemCode{
Code: model.GenerateRedeemCode(),
Code: codeValue,
Type: input.Type,
Value: input.Value,
Status: model.StatusUnused,

View File

@@ -6,11 +6,11 @@ import (
"encoding/hex"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
@@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查字符:只允许字母、数字、下划线、连字符
for _, c := range key {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
return ErrApiKeyInvalidChars
if (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '_' || c == '-' {
continue
}
return ErrApiKeyInvalidChars
}
return nil
@@ -452,3 +455,11 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}
return keys, nil
}

View File

@@ -4,10 +4,10 @@ import (
"context"
"errors"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"log"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/service/ports"
"time"
"github.com/golang-jwt/jwt/v5"
@@ -23,6 +23,7 @@ var (
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
ErrServiceUnavailable = errors.New("service temporarily unavailable")
)
// JWTClaims JWT载荷数据
@@ -90,7 +91,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return "", nil, fmt.Errorf("check email exists: %w", err)
log.Printf("[Auth] Database error checking email exists: %v", err)
return "", nil, ErrServiceUnavailable
}
if existsEmail {
return "", nil, ErrEmailExists
@@ -121,7 +123,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
if err := s.userRepo.Create(ctx, user); err != nil {
return "", nil, fmt.Errorf("create user: %w", err)
log.Printf("[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
// 生成token
@@ -148,7 +151,8 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return fmt.Errorf("check email exists: %w", err)
log.Printf("[Auth] Database error checking email exists: %v", err)
return ErrServiceUnavailable
}
if existsEmail {
return ErrEmailExists
@@ -181,8 +185,8 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Error checking email exists: %v", err)
return nil, fmt.Errorf("check email exists: %w", err)
log.Printf("[Auth] Database error checking email exists: %v", err)
return nil, ErrServiceUnavailable
}
if existsEmail {
log.Printf("[Auth] Email already exists: %s", email)
@@ -254,7 +258,9 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", nil, ErrInvalidCredentials
}
return "", nil, fmt.Errorf("get user by email: %w", err)
// 记录数据库错误但不暴露给用户
log.Printf("[Auth] Database error during login: %v", err)
return "", nil, ErrServiceUnavailable
}
// 验证密码
@@ -278,7 +284,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@@ -354,7 +360,8 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrInvalidToken
}
return "", fmt.Errorf("get user: %w", err)
log.Printf("[Auth] Database error refreshing token: %v", err)
return "", ErrServiceUnavailable
}
// 检查用户状态

View File

@@ -7,8 +7,8 @@ import (
"log"
"time"
"sub2api/internal/model"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
// 错误定义

View File

@@ -2,9 +2,9 @@ package service
import (
"fmt"
"github.com/Wei-Shaw/sub2api/internal/config"
"log"
"strings"
"sub2api/internal/config"
)
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致
@@ -259,11 +259,11 @@ func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, es
}
// GetPricingServiceStatus 获取价格服务状态
func (s *BillingService) GetPricingServiceStatus() map[string]interface{} {
func (s *BillingService) GetPricingServiceStatus() map[string]any {
if s.pricingService != nil {
return s.pricingService.GetStatus()
}
return map[string]interface{}{
return map[string]any{
"model_count": len(s.fallbackPrices),
"last_updated": "using fallback",
"local_hash": "N/A",

View File

@@ -2,19 +2,27 @@ package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"time"
"sub2api/internal/service/ports"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
// generateRequestID generates a unique request ID for concurrency slot tracking
// Uses 8 random bytes (16 hex chars) for uniqueness
func generateRequestID() string {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
// Fallback to nanosecond timestamp (extremely rare case)
return fmt.Sprintf("%x", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}
const (
// Wait polling interval
waitPollInterval = 100 * time.Millisecond
// Default max wait time
defaultMaxWait = 60 * time.Second
// Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20
)
@@ -31,7 +39,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
// AcquireResult represents the result of acquiring a concurrency slot
type AcquireResult struct {
Acquired bool
Acquired bool
ReleaseFunc func() // Must be called when done (typically via defer)
}
@@ -47,19 +55,22 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
}, nil
}
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
// Generate unique request ID for this slot
requestID := generateRequestID()
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
if err != nil {
return nil, err
}
if acquired {
return &AcquireResult{
Acquired: true,
Acquired: true,
ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
}
},
}, nil
@@ -83,19 +94,22 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
}, nil
}
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
// Generate unique request ID for this slot
requestID := generateRequestID()
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
if err != nil {
return nil, err
}
if acquired {
return &AcquireResult{
Acquired: true,
Acquired: true,
ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
}
},
}, nil
@@ -153,3 +167,20 @@ func CalculateMaxWait(userConcurrency int) int {
}
return userConcurrency + defaultExtraWaitSlots
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int)
for _, accountID := range accountIDs {
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
if err != nil {
// If key doesn't exist in Redis, count is 0
count = 0
}
result[accountID] = count
}
return result, nil
}

View File

@@ -0,0 +1,961 @@
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
type CRSSyncService struct {
accountRepo ports.AccountRepository
proxyRepo ports.ProxyRepository
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
}
func NewCRSSyncService(
accountRepo ports.AccountRepository,
proxyRepo ports.ProxyRepository,
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
proxyRepo: proxyRepo,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
}
}
type SyncFromCRSInput struct {
BaseURL string
Username string
Password string
SyncProxies bool
}
type SyncFromCRSItemResult struct {
CRSAccountID string `json:"crs_account_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Action string `json:"action"` // created/updated/failed/skipped
Error string `json:"error,omitempty"`
}
type SyncFromCRSResult struct {
Created int `json:"created"`
Updated int `json:"updated"`
Skipped int `json:"skipped"`
Failed int `json:"failed"`
Items []SyncFromCRSItemResult `json:"items"`
}
type crsLoginResponse struct {
Success bool `json:"success"`
Token string `json:"token"`
Message string `json:"message"`
Error string `json:"error"`
Username string `json:"username"`
}
type crsExportResponse struct {
Success bool `json:"success"`
Error string `json:"error"`
Message string `json:"message"`
Data struct {
ExportedAt string `json:"exportedAt"`
ClaudeAccounts []crsClaudeAccount `json:"claudeAccounts"`
ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"`
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
} `json:"data"`
}
type crsProxy struct {
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
}
type crsClaudeAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
AuthType string `json:"authType"` // oauth/setup-token
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
type crsConsoleAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
MaxConcurrentTasks int `json:"maxConcurrentTasks"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
}
type crsOpenAIResponsesAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
}
type crsOpenAIOAuthAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
AuthType string `json:"authType"` // oauth
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL)
if err != nil {
return nil, err
}
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required")
}
client := &http.Client{Timeout: 20 * time.Second}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
if err != nil {
return nil, err
}
now := time.Now().UTC().Format(time.RFC3339)
result := &SyncFromCRSResult{
Items: make(
[]SyncFromCRSItemResult,
0,
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts),
),
}
var proxies []model.Proxy
if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx)
}
// Claude OAuth / Setup Token -> sub2api anthropic oauth/setup-token
for _, src := range exported.Data.ClaudeAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
targetType := strings.TrimSpace(src.AuthType)
if targetType == "" {
targetType = "oauth"
}
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken {
item.Action = "skipped"
item.Error = "unsupported authType: " + targetType
result.Skipped++
result.Items = append(result.Items, item)
continue
}
accessToken, _ := src.Credentials["access_token"].(string)
if strings.TrimSpace(accessToken) == "" {
item.Action = "failed"
item.Error = "missing access_token"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
// 🔧 Remove /v1 suffix from base_url for Claude accounts
cleanBaseURL(credentials, "/v1")
// 🔧 Convert expires_at from ISO string to Unix timestamp
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
credentials["expires_at"] = t.Unix()
}
}
// 🔧 Add intercept_warmup_requests if not present (defaults to false)
if _, exists := credentials["intercept_warmup_requests"]; !exists {
credentials["intercept_warmup_requests"] = false
}
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
// 🔧 Preserve all CRS extra fields and add sync metadata
extra := make(map[string]any)
if src.Extra != nil {
for k, v := range src.Extra {
extra[k] = v
}
}
extra["crs_account_id"] = src.ID
extra["crs_kind"] = src.Kind
extra["crs_synced_at"] = now
// Extract org_uuid and account_uuid from CRS credentials to extra
if orgUUID, ok := src.Credentials["org_uuid"]; ok {
extra["org_uuid"] = orgUUID
}
if accountUUID, ok := src.Credentials["account_uuid"]; ok {
extra["account_uuid"] = accountUUID
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Type: targetType,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after creation
if targetType == model.AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
}
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
// Update existing
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Type = targetType
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after update
if targetType == model.AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
}
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// Claude Console API Key -> sub2api anthropic apikey
for _, src := range exported.Data.ClaudeConsoleAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
apiKey, _ := src.Credentials["api_key"].(string)
if strings.TrimSpace(apiKey) == "" {
item.Action = "failed"
item.Error = "missing api_key"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
priority := clampPriority(src.Priority)
concurrency := 3
if src.MaxConcurrentTasks > 0 {
concurrency = src.MaxConcurrentTasks
}
status := mapCRSStatus(src.IsActive, src.Status)
extra := map[string]any{
"crs_account_id": src.ID,
"crs_kind": src.Kind,
"crs_synced_at": now,
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// OpenAI OAuth -> sub2api openai oauth
for _, src := range exported.Data.OpenAIOAuthAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
accessToken, _ := src.Credentials["access_token"].(string)
if strings.TrimSpace(accessToken) == "" {
item.Action = "failed"
item.Error = "missing access_token"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(
ctx,
input.SyncProxies,
&proxies,
src.Proxy,
fmt.Sprintf("crs-%s", src.Name),
)
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
// Normalize token_type
if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" {
credentials["token_type"] = "Bearer"
}
// 🔧 Convert expires_at from ISO string to Unix timestamp
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
credentials["expires_at"] = t.Unix()
}
}
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
// 🔧 Preserve all CRS extra fields and add sync metadata
extra := make(map[string]any)
if src.Extra != nil {
for k, v := range src.Extra {
extra[k] = v
}
}
extra["crs_account_id"] = src.ID
extra["crs_kind"] = src.Kind
extra["crs_synced_at"] = now
// Extract email from CRS extra (crs_email -> email)
if crsEmail, ok := src.Extra["crs_email"]; ok {
extra["email"] = crsEmail
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeOAuth,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after creation
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after update
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// OpenAI Responses API Key -> sub2api openai apikey
for _, src := range exported.Data.OpenAIResponsesAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
apiKey, _ := src.Credentials["api_key"].(string)
if strings.TrimSpace(apiKey) == "" {
item.Action = "failed"
item.Error = "missing api_key"
result.Failed++
result.Items = append(result.Items, item)
continue
}
if baseURL, ok := src.Credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" {
src.Credentials["base_url"] = "https://api.openai.com"
}
// 🔧 Remove /v1 suffix from base_url for OpenAI accounts
cleanBaseURL(src.Credentials, "/v1")
proxyID, err := s.mapOrCreateProxy(
ctx,
input.SyncProxies,
&proxies,
src.Proxy,
fmt.Sprintf("crs-%s", src.Name),
)
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
extra := map[string]any{
"crs_account_id": src.ID,
"crs_kind": src.Kind,
"crs_synced_at": now,
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
return result, nil
}
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates.
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
out := make(model.JSONB)
for k, v := range existing {
out[k] = v
}
for k, v := range updates {
out[k] = v
}
return out
}
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) {
if !enabled || src == nil {
return nil, nil
}
protocol := strings.ToLower(strings.TrimSpace(src.Protocol))
switch protocol {
case "socks":
protocol = "socks5"
case "socks5h":
protocol = "socks5"
}
host := strings.TrimSpace(src.Host)
port := src.Port
username := strings.TrimSpace(src.Username)
password := strings.TrimSpace(src.Password)
if protocol == "" || host == "" || port <= 0 {
return nil, nil
}
if protocol != "http" && protocol != "https" && protocol != "socks5" {
return nil, nil
}
// Find existing proxy (active only).
for _, p := range *cached {
if strings.EqualFold(p.Protocol, protocol) &&
p.Host == host &&
p.Port == port &&
p.Username == username &&
p.Password == password {
id := p.ID
return &id, nil
}
}
// Create new proxy
proxy := &model.Proxy{
Name: defaultProxyName(defaultName, protocol, host, port),
Protocol: protocol,
Host: host,
Port: port,
Username: username,
Password: password,
Status: model.StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
*cached = append(*cached, *proxy)
id := proxy.ID
return &id, nil
}
func defaultProxyName(base, protocol, host string, port int) string {
base = strings.TrimSpace(base)
if base == "" {
base = "crs"
}
return fmt.Sprintf("%s (%s://%s:%d)", base, protocol, host, port)
}
func defaultName(name, id string) string {
if strings.TrimSpace(name) != "" {
return strings.TrimSpace(name)
}
return "CRS " + id
}
func clampPriority(priority int) int {
if priority < 1 || priority > 100 {
return 50
}
return priority
}
func sanitizeCredentialsMap(input map[string]any) map[string]any {
if input == nil {
return map[string]any{}
}
out := make(map[string]any, len(input))
for k, v := range input {
// Avoid nil values to keep JSONB cleaner
if v != nil {
out[k] = v
}
}
return out
}
func mapCRSStatus(isActive bool, status string) string {
if !isActive {
return "inactive"
}
if strings.EqualFold(strings.TrimSpace(status), "error") {
return "error"
}
return "active"
}
func normalizeBaseURL(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("base_url is required")
}
u, err := url.Parse(trimmed)
if err != nil || u.Scheme == "" || u.Host == "" {
return "", fmt.Errorf("invalid base_url: %s", trimmed)
}
u.Path = strings.TrimRight(u.Path, "/")
return strings.TrimRight(u.String(), "/"), nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials
// Used for both Claude and OpenAI accounts to remove /v1
func cleanBaseURL(credentials map[string]any, suffixToRemove string) {
if baseURL, ok := credentials["base_url"].(string); ok && baseURL != "" {
trimmed := strings.TrimSpace(baseURL)
if strings.HasSuffix(trimmed, suffixToRemove) {
credentials["base_url"] = strings.TrimSuffix(trimmed, suffixToRemove)
}
}
}
func crsLogin(ctx context.Context, client *http.Client, baseURL, username, password string) (string, error) {
payload := map[string]any{
"username": username,
"password": password,
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/web/auth/login", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("crs login failed: status=%d body=%s", resp.StatusCode, string(raw))
}
var parsed crsLoginResponse
if err := json.Unmarshal(raw, &parsed); err != nil {
return "", fmt.Errorf("crs login parse failed: %w", err)
}
if !parsed.Success || strings.TrimSpace(parsed.Token) == "" {
msg := parsed.Message
if msg == "" {
msg = parsed.Error
}
if msg == "" {
msg = "unknown error"
}
return "", errors.New("crs login failed: " + msg)
}
return parsed.Token, nil
}
func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminToken string) (*crsExportResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/admin/sync/export-accounts?include_secrets=true", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+adminToken)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 5<<20))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("crs export failed: status=%d body=%s", resp.StatusCode, string(raw))
}
var parsed crsExportResponse
if err := json.Unmarshal(raw, &parsed); err != nil {
return nil, fmt.Errorf("crs export parse failed: %w", err)
}
if !parsed.Success {
msg := parsed.Message
if msg == "" {
msg = parsed.Error
}
if msg == "" {
msg = "unknown error"
}
return nil, errors.New("crs export failed: " + msg)
}
return &parsed, nil
}
// refreshOAuthToken attempts to refresh OAuth token for a synced account
// Returns updated credentials or nil if refresh failed/not applicable
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB {
if account.Type != model.AccountTypeOAuth {
return nil
}
var newCredentials map[string]any
var err error
switch account.Platform {
case model.PlatformAnthropic:
if s.oauthService == nil {
return nil
}
tokenInfo, refreshErr := s.oauthService.RefreshAccountToken(ctx, account)
if refreshErr != nil {
err = refreshErr
} else {
// Preserve existing credentials
newCredentials = make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
// Update token fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
}
case model.PlatformOpenAI:
if s.openaiOAuthService == nil {
return nil
}
tokenInfo, refreshErr := s.openaiOAuthService.RefreshAccountToken(ctx, account)
if refreshErr != nil {
err = refreshErr
} else {
newCredentials = s.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
}
default:
return nil
}
if err != nil {
// Log but don't fail the sync - token might still be valid or refreshable later
return nil
}
return model.JSONB(newCredentials)
}

View File

@@ -0,0 +1,77 @@
package service
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
// DashboardService provides aggregated statistics for admin dashboard.
type DashboardService struct {
usageRepo ports.UsageLogRepository
}
func NewDashboardService(usageRepo ports.UsageLogRepository) *DashboardService {
return &DashboardService{
usageRepo: usageRepo,
}
}
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
stats, err := s.usageRepo.GetDashboardStats(ctx)
if err != nil {
return nil, fmt.Errorf("get dashboard stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
trend, err := s.usageRepo.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get user usage trend: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
if err != nil {
return nil, fmt.Errorf("get batch user usage stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}
return stats, nil
}

Some files were not shown because too many files have changed in this diff Show More