Compare commits

...

31 Commits

Author SHA1 Message Date
刀刀
9d30ceae8d CC 400 返回具体错误信息 && 非 CC 请求时增加 system prompt (#26)
* feat: http 400 返回具体错误

* 更新 workflows

* 优化打包/docker 构建流程

* 400 是返回 原始错误 - json 格式

* feat: 非 cc请求时补充 system

* go mod tidy
2025-12-25 14:47:19 +08:00
IanShaw
60f6ed6bf6 feat: CRS 同步增强 - 自动刷新 OAuth token 和修复测试配置 (#27)
* fix(service): 修复 OpenAI Responses API 测试负载配置

- 所有账号类型统一添加 instructions 字段(不再仅限 OAuth)
- Responses API 要求所有请求必须包含 instructions 参数

* feat(crs-sync): CRS 同步时自动刷新 OAuth token 并保留完整 extra 字段

**核心功能**:
- CRSSyncService 注入 OAuth 服务依赖(Anthropic + OpenAI)
- 账号创建/更新后自动刷新 OAuth token,确保可用性
- 完整保留 CRS extra 字段,避免数据丢失

**Extra 字段增强**:
- 保留 CRS 所有原始 extra 字段
- 新增同步元数据: crs_account_id, crs_kind, crs_synced_at
- Claude 账号: 从 credentials 提取 org_uuid/account_uuid 到 extra
- OpenAI 账号: 映射 crs_email -> email

**Token 刷新逻辑**:
- 新增 refreshOAuthToken() 方法处理 Anthropic/OpenAI 平台
- 保留原有 credentials 字段,仅更新 token 相关字段
- 刷新失败静默处理,不中断同步流程

**依赖注入**:
- wire_gen.go: CRSSyncService 新增 oAuthService/openaiOAuthService

* style(crs-sync): 使用 switch 替代 if-else 修复 golangci-lint 警告

- 将 refreshOAuthToken 中的 if-else 改为 switch 语句
- 符合 staticcheck 规范
- 添加 default 分支处理未知平台
2025-12-25 14:45:17 +08:00
shaw
4a2f7d4a99 chore: CRS迁移功能增加版本提示 2025-12-25 10:57:04 +08:00
shaw
c19a393be9 Merge PR #24: feat: 添加账户同步与批量编辑功能
- 添加从 CRS 同步账户功能 (Claude OAuth/API Key, OpenAI OAuth/Responses)
- 添加批量编辑账户功能,支持 JSONB 字段智能合并
- 新增 CRSSyncService、BulkUpdate 仓储方法
- 前端新增 SyncFromCrsModal 和 BulkEditAccountModal 组件
2025-12-25 10:44:40 +08:00
ianshaw
938ffb002e style(frontend): format code with prettier
格式化前端业务代码,符合代码规范
- 统一代码风格
- 修复 ESLint 警告
2025-12-24 18:07:58 -08:00
ianshaw
372a01290b fix(backend): handle defer Close() errors in crs_sync_service
修复 golangci-lint 错误检查问题
- 使用匿名函数包装 defer Close() 并忽略错误
- 符合 Go 最佳实践
2025-12-24 17:58:47 -08:00
ianshaw
8b163ca49b chore: trigger CI after enabling Actions 2025-12-24 17:56:55 -08:00
ianshaw
d23810dc53 chore: trigger CI workflow 2025-12-24 17:54:43 -08:00
ianshaw
62ed5422dd feat(account): 优化批量更新实现,使用统一 SQL 合并 JSONB 字段
- 新增 BulkUpdate 仓储方法,使用单条 SQL 更新所有账户
- credentials/extra 使用 COALESCE(...) || ? 合并,只更新传入的 key
- name/proxy_id/concurrency/priority/status 只在提供时更新
- 分组绑定仍逐账号处理(需要独立操作)
- 前端优化:Base URL 留空则不修改,按勾选字段更新
- 完善 i18n 文案:说明留空不修改、批量更新行为
2025-12-24 17:16:19 -08:00
ianshaw
2e76302af7 feat(account): 添加批量编辑账户凭据功能并优化 CRS 同步
- 新增批量更新账户凭据接口(account_uuid/org_uuid/intercept_warmup_requests)
- 新增前端批量编辑模态框组件
- 优化 CRS 同步逻辑,改进 extra 字段处理
- 优化 CRS 同步 UI,添加更详细的结果展示
- 完善国际化文案(中英文)
2025-12-24 16:56:48 -08:00
ianshaw
6553828008 feat(account): 添加从 CRS 同步账户功能
- 添加账户同步 API 接口 (account_handler.go)
- 实现 CRS 同步服务 (crs_sync_service.go)
- 添加前端同步对话框组件 (SyncFromCrsModal.vue)
- 更新账户管理界面支持同步操作
- 添加账户仓库批量创建方法
- 添加中英文国际化翻译
- 更新依赖注入配置
2025-12-24 08:48:58 -08:00
ianshaw
adcb7bf00e chore: 更新 .gitignore 忽略配置文件并还原 Makefile
- 添加 backend/config.yaml 到 .gitignore(包含敏感信息)
- 添加 deploy/config.yaml 到 .gitignore(包含敏感信息)
- 添加 backend/.installed 到 .gitignore
- 还原 Makefile 到原始版本
2025-12-24 08:48:49 -08:00
shaw
876e85e7ad Merge branch 'feat/rename-go-module' 2025-12-24 21:34:37 +08:00
shaw
2e7818d688 feat(settings): 添加文档链接配置功能
- 后台系统设置新增文档链接(doc_url)配置项
- 首页顶部导航栏显示文档链接图标(条件渲染)
- Footer区域添加文档链接和GitHub链接
- 支持中英文国际化
2025-12-24 21:30:19 +08:00
Forest
836c4dda2b refactor: 重命名 go module 2025-12-24 21:07:21 +08:00
shaw
e65e9587b4 fix(concurrency): 重构并发管理使用独立Key+原生TTL
问题:旧方案使用计数器模式,每次acquire都刷新TTL,导致僵尸数据永不过期

解决方案:
- 每个槽位使用独立Redis Key: concurrency:account:{id}:{requestID}
- 利用Redis原生TTL,每个槽位独立5分钟过期
- 服务崩溃后僵尸数据自动清理,无需手动干预
- 兼容多实例K8s部署

技术改动:
- 新增SCAN脚本统计活跃槽位数量
- 移除冗余的releaseScript,直接使用DEL命令
- Wait队列TTL只在首次创建时设置,避免刷新
2025-12-24 21:00:29 +08:00
shaw
aaadd6ed04 fix(dashboard): 修复性能指标 RPM/TPM 显示为0的问题
- 修复 Admin Dashboard Handler 遗漏返回 rpm/tpm 字段
- 将性能统计时间窗口从1分钟改为5分钟平均值,数据更稳定
2025-12-24 19:58:33 +08:00
shaw
870b21916c feat(install): 添加安装指定版本和回退功能
- 新增 rollback 命令支持回退到指定版本
- 新增 list-versions 命令列出可用版本
- 新增 -v/--version 参数指定安装版本
- upgrade 命令支持升级到指定版本
- 添加安装状态检查,未安装时给出明确提示
- 版本切换仅替换二进制文件,保留配置和数据
- 自动备份当前版本(带版本号或时间戳后缀)
- 改进网络错误处理,添加超时和友好提示
- 修复 grep -oP 兼容性问题,改用 grep -oE
2025-12-24 17:44:13 +08:00
shaw
fb119f9a67 fix(version): 优化服务重启后页面刷新时机
- 将重启后等待时间从 3 秒增加到 8 秒
- 添加倒计时显示,提升用户体验
- 倒计时结束后先检测服务健康状态再刷新页面
- 避免刷新过早导致 502 错误
2025-12-24 17:21:17 +08:00
shaw
ad54795a24 feat(gateway): 添加上游错误重试机制
- OAuth/Setup Token 账号遇到 403 错误时,等待 2 秒后重试,最多 3 次
- Console 账号遇到未配置的错误码时,同样进行重试
- 重试耗尽后:OAuth 403 标记账号异常,Console 未配置错误码不标记账号
- 移除 handleErrorResponse 中已被重试逻辑覆盖的死代码
2025-12-24 16:55:46 +08:00
shaw
0abe322cca feat(accounts): 账户列表显示实时并发数
- 在账户列表 API 返回中添加 current_concurrency 字段
- 合并平台和类型列为 PlatformTypeBadge 组件,节省表格空间
- 新增并发状态列,显示 当前/最大 并发数,支持颜色编码
2025-12-24 15:44:45 +08:00
shaw
b071511676 refactor(accounts): 优化用量窗口显示,统一 OAuth 和 Setup Token 处理
- Setup Token 账号现在也调用 API 获取 5h 窗口用量数据
- 重新设计 UsageProgressBar UI,将用量统计移到进度条上方
- 删除冗余的 SetupTokenTimeWindow 组件
- 请求数/Token数支持 K/M/B 单位显示
2025-12-24 10:57:40 +08:00
shaw
7d9a757a26 feat(dashboard): 添加 RPM/TPM 性能指标
在 Dashboard 中用 RPM/TPM 卡片替换原来的"今日缓存"卡片,
实时显示最近1分钟的请求数和 Token 吞吐量。
2025-12-24 10:24:02 +08:00
Forest
bbf4024dc7 refactor(usage): 移动 usage 查询到 services 2025-12-24 08:41:31 +08:00
shaw
5831eb8a6a fix: 修复Claude OAuth token交换时authorization code解析错误
原代码中 `parts` 变量被创建但从未使用,导致 `len(parts) == 0`
永远为 true,使得即使成功从 `code#state` 格式中分割出 authCode,
最后也会被覆盖为原始的完整字符串。

这导致传递给Claude Token端点的code包含了 `#state` 部分,
Claude返回 "Invalid 'code' in request" 错误。
2025-12-23 19:42:52 +08:00
shaw
61838cdb3d fix: 兼容GLM等API的usage数据解析
部分第三方API(如GLM)的SSE响应格式与标准Claude API不同:
- 标准Claude: input_tokens在message_start中
- GLM等API: 所有tokens都在message_delta中

现在从message_delta中也解析input_tokens和cache相关字段,
如果message_start中没有值则使用message_delta中的数据。
2025-12-23 19:42:52 +08:00
dexcoder6
50dba656fd feat: 添加用户余额充值/退款功能 (#17)
## 功能特性

### 前端
- 在用户列表操作列添加充值和退款按钮
- 实现充值/退款对话框,支持输入金额和备注
- 从编辑用户表单中移除余额字段,防止直接修改
- 添加余额不足验证,实时显示操作后余额
- 优化备注提示词,提供多种场景示例

### 后端
- 为 redeem_codes 表添加 notes 字段(迁移文件)
- 在 UpdateUserBalance 接口添加 notes 参数支持
- 添加余额验证:金额必须大于0,操作后余额不能为负
- UpdateUser 接口移除 balance 字段处理,防止误操作
- 完整的审计日志和缓存管理

## 安全保护

- 前端:余额不足时禁用提交按钮,实时提示
- 后端:双重验证(输入金额 > 0 + 结果余额 >= 0)
- 权限:仅管理员可访问(AdminAuth 中间件)
- 审计:所有操作记录到 redeem_codes 表

## 修改文件

后端:
- backend/migrations/004_add_redeem_code_notes.sql
- backend/internal/model/redeem_code.go
- backend/internal/service/admin_service.go
- backend/internal/handler/admin/user_handler.go

前端:
- frontend/src/views/admin/UsersView.vue
- frontend/src/api/admin/users.ts
- frontend/src/i18n/locales/zh.ts
- frontend/src/i18n/locales/en.ts

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-23 16:29:57 +08:00
shaw
0e2821456c chore: 忽略TypeScript增量编译缓存文件 2025-12-23 16:27:56 +08:00
shaw
f25ac3aff5 feat: OpenAI OAuth账号显示Codex使用量
从响应头提取x-codex-*使用量信息并保存到账号Extra字段,
前端账号列表展示5h/7d窗口的使用进度条。
2025-12-23 16:26:07 +08:00
shaw
f6341b7f2b chore: 将"代理管理"菜单更名为"IP管理" 2025-12-23 15:46:10 +08:00
shaw
4e257512b9 style: 统一平台和分组列的样式
- 账号页面平台列改为与分组页面一致的标签样式
- 订阅页面分组列改用 GroupBadge 组件展示
- 修正 OpenAI OAuth 类型描述文案
2025-12-23 15:40:22 +08:00
144 changed files with 5255 additions and 1220 deletions

View File

@@ -85,6 +85,19 @@ jobs:
go-version: '1.24' go-version: '1.24'
cache-dependency-path: backend/go.sum cache-dependency-path: backend/go.sum
# Docker setup for GoReleaser
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Fetch tags with annotations - name: Fetch tags with annotations
run: | run: |
# 确保获取完整的 annotated tag 信息 # 确保获取完整的 annotated tag 信息
@@ -117,87 +130,16 @@ jobs:
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }} TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
# =========================================================================== # Update DockerHub description
# Docker Build and Push
# ===========================================================================
docker:
needs: [update-version, build-frontend]
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Download VERSION artifact
uses: actions/download-artifact@v4
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
uses: actions/download-artifact@v4
with:
name: frontend-dist
path: backend/internal/web/dist/
# Extract version from tag
- name: Extract version
id: version
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Version: $VERSION"
# Set up Docker Buildx for multi-platform builds
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Login to DockerHub
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Extract metadata for Docker
- name: Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: |
weishaw/sub2api
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=raw,value=latest,enable={{is_default_branch}}
# Build and push Docker image
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
COMMIT=${{ github.sha }}
DATE=${{ github.event.head_commit.timestamp }}
cache-from: type=gha
cache-to: type=gha,mode=max
# Update DockerHub description (optional)
- name: Update DockerHub description - name: Update DockerHub description
uses: peter-evans/dockerhub-description@v4 uses: peter-evans/dockerhub-description@v4
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
repository: weishaw/sub2api repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
short-description: "Sub2API - AI API Gateway Platform" short-description: "Sub2API - AI API Gateway Platform"
readme-filepath: ./deploy/DOCKER.md readme-filepath: ./deploy/DOCKER.md

8
.gitignore vendored
View File

@@ -28,6 +28,7 @@ node_modules/
frontend/node_modules/ frontend/node_modules/
frontend/dist/ frontend/dist/
*.local *.local
*.tsbuildinfo
# 日志 # 日志
npm-debug.log* npm-debug.log*
@@ -91,6 +92,13 @@ backend/internal/web/dist/*
# 后端运行时缓存数据 # 后端运行时缓存数据
backend/data/ backend/data/
# ===================
# 本地配置文件(包含敏感信息)
# ===================
backend/config.yaml
deploy/config.yaml
backend/.installed
# =================== # ===================
# 其他 # 其他
# =================== # ===================

View File

@@ -52,10 +52,58 @@ changelog:
# 禁用自动 changelog完全使用 tag 消息 # 禁用自动 changelog完全使用 tag 消息
disable: true disable: true
# Docker images
dockers:
- id: amd64
goos: linux
goarch: amd64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
- id: arm64
goos: linux
goarch: arm64
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .Commit }}"
# Docker manifests for multi-arch support
docker_manifests:
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
image_templates:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
release: release:
github: github:
owner: Wei-Shaw owner: "{{ .Env.GITHUB_REPO_OWNER }}"
name: sub2api name: "{{ .Env.GITHUB_REPO_NAME }}"
draft: false draft: false
prerelease: auto prerelease: auto
name_template: "Sub2API {{.Version}}" name_template: "Sub2API {{.Version}}"
@@ -73,7 +121,7 @@ release:
**One-line install (Linux):** **One-line install (Linux):**
```bash ```bash
curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install.sh | sudo bash curl -sSL https://raw.githubusercontent.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/main/deploy/install.sh | sudo bash
``` ```
**Manual download:** **Manual download:**
@@ -81,5 +129,5 @@ release:
## 📚 Documentation ## 📚 Documentation
- [GitHub Repository](https://github.com/Wei-Shaw/sub2api) - [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
- [Installation Guide](https://github.com/Wei-Shaw/sub2api/blob/main/deploy/README.md) - [Installation Guide](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}/blob/main/deploy/README.md)

40
Dockerfile.goreleaser Normal file
View File

@@ -0,0 +1,40 @@
# =============================================================================
# Sub2API Dockerfile for GoReleaser
# =============================================================================
# This Dockerfile is used by GoReleaser to build Docker images.
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
# Install runtime dependencies
RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
&& rm -rf /var/cache/apk/*
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
WORKDIR /app
# Copy pre-built binary from GoReleaser
COPY sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
USER sub2api
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
ENTRYPOINT ["/app/sub2api"]

View File

@@ -17,10 +17,17 @@ linters:
service-no-repository: service-no-repository:
list-mode: original list-mode: original
files: files:
- internal/service/** - "**/internal/service/**"
deny: deny:
- pkg: sub2api/internal/repository - pkg: sub2api/internal/repository
desc: "service must not import repository" desc: "service must not import repository"
handler-no-repository:
list-mode: original
files:
- "**/internal/handler/**"
deny:
- pkg: sub2api/internal/repository
desc: "handler must not import repository"
errcheck: errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`. # Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default. # Such cases aren't reported by default.

View File

@@ -15,11 +15,11 @@ import (
"syscall" "syscall"
"time" "time"
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"sub2api/internal/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
"sub2api/internal/setup" "github.com/Wei-Shaw/sub2api/internal/setup"
"sub2api/internal/web" "github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -4,12 +4,12 @@
package main package main
import ( import (
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/infrastructure"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
"sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"context" "context"
"log" "log"

View File

@@ -8,17 +8,17 @@ package main
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
"log" "log"
"net/http" "net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/handler/admin"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"time" "time"
) )
@@ -58,7 +58,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db) usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository) usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db) redeemCodeRepository := repository.NewRedeemCodeRepository(db)
billingCache := repository.NewBillingCache(client) billingCache := repository.NewBillingCache(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
@@ -67,7 +67,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardHandler := admin.NewDashboardHandler(usageLogRepository) dashboardService := service.NewDashboardService(usageLogRepository)
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(db) accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db) proxyRepository := repository.NewProxyRepository(db)
proxyExitInfoProber := repository.NewProxyExitInfoProber() proxyExitInfoProber := repository.NewProxyExitInfoProber()
@@ -83,7 +84,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
httpUpstream := repository.NewHTTPUpstream(configConfig) httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, usageLogRepository) concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService) proxyHandler := admin.NewProxyHandler(adminService)
@@ -95,7 +99,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService) systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService) adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client) gatewayCache := repository.NewGatewayCache(client)
pricingRemoteClient := repository.NewPricingRemoteClient() pricingRemoteClient := repository.NewPricingRemoteClient()
@@ -107,8 +111,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache := repository.NewIdentityCache(client) identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream) gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)

View File

@@ -1,38 +0,0 @@
server:
host: "0.0.0.0"
port: 8080
mode: "debug" # debug/release
database:
host: "127.0.0.1"
port: 5432
user: "postgres"
password: "XZeRr7nkjHWhm8fw"
dbname: "sub2api"
sslmode: "disable"
redis:
host: "127.0.0.1"
port: 6379
password: ""
db: 0
jwt:
secret: "your-secret-key-change-in-production"
expire_hour: 24
default:
admin_email: "admin@sub2api.com"
admin_password: "admin123"
user_concurrency: 5
user_balance: 0
api_key_prefix: "sk-"
rate_multiplier: 1.0
# Timezone configuration (similar to PHP's date_default_timezone_set)
# This affects ALL time operations:
# - Database timestamps
# - Usage statistics "today" boundary
# - Subscription expiry times
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
timezone: "Asia/Shanghai"

View File

@@ -1,4 +1,4 @@
module sub2api module github.com/Wei-Shaw/sub2api
go 1.24.0 go 1.24.0
@@ -8,10 +8,13 @@ require (
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0 github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/google/wire v0.7.0
github.com/imroc/req/v3 v3.56.0 github.com/imroc/req/v3 v3.56.0
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/redis/go-redis/v9 v9.3.0 github.com/redis/go-redis/v9 v9.3.0
github.com/spf13/viper v1.18.2 github.com/spf13/viper v1.18.2
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/crypto v0.44.0 golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0 golang.org/x/net v0.47.0
golang.org/x/term v0.37.0 golang.org/x/term v0.37.0
@@ -35,7 +38,6 @@ require (
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect github.com/google/subcommands v1.2.0 // indirect
github.com/google/wire v0.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -64,6 +66,8 @@ require (
github.com/spf13/cast v1.6.0 // indirect github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect

View File

@@ -139,6 +139,15 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=

View File

@@ -3,12 +3,12 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -33,11 +33,21 @@ type AccountHandler struct {
rateLimitService *service.RateLimitService rateLimitService *service.RateLimitService
accountUsageService *service.AccountUsageService accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService accountTestService *service.AccountTestService
usageLogRepo *repository.UsageLogRepository concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
} }
// NewAccountHandler creates a new admin account handler // NewAccountHandler creates a new admin account handler
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, openaiOAuthService *service.OpenAIOAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService, usageLogRepo *repository.UsageLogRepository) *AccountHandler { func NewAccountHandler(
adminService service.AdminService,
oauthService *service.OAuthService,
openaiOAuthService *service.OpenAIOAuthService,
rateLimitService *service.RateLimitService,
accountUsageService *service.AccountUsageService,
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
) *AccountHandler {
return &AccountHandler{ return &AccountHandler{
adminService: adminService, adminService: adminService,
oauthService: oauthService, oauthService: oauthService,
@@ -45,7 +55,8 @@ func NewAccountHandler(adminService service.AdminService, oauthService *service.
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
accountUsageService: accountUsageService, accountUsageService: accountUsageService,
accountTestService: accountTestService, accountTestService: accountTestService,
usageLogRepo: usageLogRepo, concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
} }
} }
@@ -76,6 +87,25 @@ type UpdateAccountRequest struct {
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
} }
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*model.Account
CurrentConcurrency int `json:"current_concurrency"`
}
// List handles listing all accounts with pagination // List handles listing all accounts with pagination
// GET /api/v1/admin/accounts // GET /api/v1/admin/accounts
func (h *AccountHandler) List(c *gin.Context) { func (h *AccountHandler) List(c *gin.Context) {
@@ -91,7 +121,28 @@ func (h *AccountHandler) List(c *gin.Context) {
return return
} }
response.Paginated(c, accounts, total, page, pageSize) // Get current concurrency counts for all accounts
accountIDs := make([]int64, len(accounts))
for i, acc := range accounts {
accountIDs[i] = acc.ID
}
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
if err != nil {
// Log error but don't fail the request, just use 0 for all
concurrencyCounts = make(map[int64]int)
}
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
result[i] = AccountWithConcurrency{
Account: &accounts[i],
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
}
}
response.Paginated(c, result, total, page, pageSize)
} }
// GetByID handles getting an account by ID // GetByID handles getting an account by ID
@@ -197,6 +248,13 @@ type TestAccountRequest struct {
ModelID string `json:"model_id"` ModelID string `json:"model_id"`
} }
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
}
// Test handles testing account connectivity with SSE streaming // Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test // POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) { func (h *AccountHandler) Test(c *gin.Context) {
@@ -217,6 +275,35 @@ func (h *AccountHandler) Test(c *gin.Context) {
} }
} }
// SyncFromCRS handles syncing accounts from claude-relay-service (CRS)
// POST /api/v1/admin/accounts/sync/crs
func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
var req SyncFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Default to syncing proxies (can be disabled by explicitly setting false)
syncProxies := true
if req.SyncProxies != nil {
syncProxies = *req.SyncProxies
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
})
if err != nil {
response.BadRequest(c, "Sync failed: "+err.Error())
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials // Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh // POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) { func (h *AccountHandler) Refresh(c *gin.Context) {
@@ -314,7 +401,7 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
endTime := timezone.StartOfDay(now.AddDate(0, 0, 1)) endTime := timezone.StartOfDay(now.AddDate(0, 0, 1))
startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1)) startTime := timezone.StartOfDay(now.AddDate(0, 0, -days+1))
stats, err := h.usageLogRepo.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime) stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get account stats: "+err.Error()) response.InternalError(c, "Failed to get account stats: "+err.Error())
return return
@@ -360,6 +447,136 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
}) })
} }
// BatchUpdateCredentialsRequest represents batch credentials update request
type BatchUpdateCredentialsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Field string `json:"field" binding:"required,oneof=account_uuid org_uuid intercept_warmup_requests"`
Value any `json:"value"`
}
// BatchUpdateCredentials handles batch updating credentials fields
// POST /api/v1/admin/accounts/batch-update-credentials
func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
var req BatchUpdateCredentialsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Validate value type based on field
if req.Field == "intercept_warmup_requests" {
// Must be boolean
if _, ok := req.Value.(bool); !ok {
response.BadRequest(c, "intercept_warmup_requests must be boolean")
return
}
} else {
// account_uuid and org_uuid can be string or null
if req.Value != nil {
if _, ok := req.Value.(string); !ok {
response.BadRequest(c, req.Field+" must be string or null")
return
}
}
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"account_id": accountID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
})
}
// BulkUpdate handles bulk updating accounts with selected fields/credentials.
// POST /api/v1/admin/accounts/bulk-update
func (h *AccountHandler) BulkUpdate(c *gin.Context) {
var req BulkUpdateAccountsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
req.Status != "" ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
if !hasUpdates {
response.BadRequest(c, "No updates provided")
return
}
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
})
if err != nil {
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
return
}
response.Success(c, result)
}
// ========== OAuth Handlers ========== // ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL // GenerateAuthURLRequest represents the request for generating auth URL

View File

@@ -1,10 +1,10 @@
package admin package admin
import ( import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"strconv" "strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -12,15 +12,15 @@ import (
// DashboardHandler handles admin dashboard statistics // DashboardHandler handles admin dashboard statistics
type DashboardHandler struct { type DashboardHandler struct {
usageRepo *repository.UsageLogRepository dashboardService *service.DashboardService
startTime time.Time // Server start time for uptime calculation startTime time.Time // Server start time for uptime calculation
} }
// NewDashboardHandler creates a new admin dashboard handler // NewDashboardHandler creates a new admin dashboard handler
func NewDashboardHandler(usageRepo *repository.UsageLogRepository) *DashboardHandler { func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
return &DashboardHandler{ return &DashboardHandler{
usageRepo: usageRepo, dashboardService: dashboardService,
startTime: time.Now(), startTime: time.Now(),
} }
} }
@@ -58,7 +58,7 @@ func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
// GetStats handles getting dashboard statistics // GetStats handles getting dashboard statistics
// GET /api/v1/admin/dashboard/stats // GET /api/v1/admin/dashboard/stats
func (h *DashboardHandler) GetStats(c *gin.Context) { func (h *DashboardHandler) GetStats(c *gin.Context) {
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context()) stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics") response.Error(c, 500, "Failed to get dashboard statistics")
return return
@@ -107,6 +107,10 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
// 系统运行统计 // 系统运行统计
"average_duration_ms": stats.AverageDurationMs, "average_duration_ms": stats.AverageDurationMs,
"uptime": uptime, "uptime": uptime,
// 性能指标
"rpm": stats.Rpm,
"tpm": stats.Tpm,
}) })
} }
@@ -142,7 +146,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
} }
} }
trend, err := h.usageRepo.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID) trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get usage trend") response.Error(c, 500, "Failed to get usage trend")
return return
@@ -175,7 +179,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
} }
} }
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, 0) stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get model statistics") response.Error(c, 500, "Failed to get model statistics")
return return
@@ -200,7 +204,7 @@ func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
limit = 5 limit = 5
} }
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get API key usage trend") response.Error(c, 500, "Failed to get API key usage trend")
return return
@@ -226,7 +230,7 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
limit = 12 limit = 12
} }
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get user usage trend") response.Error(c, 500, "Failed to get user usage trend")
return return
@@ -259,7 +263,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return return
} }
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get user usage stats") response.Error(c, 500, "Failed to get user usage stats")
return return
@@ -287,7 +291,7 @@ func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
return return
} }
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs) stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get API key usage stats") response.Error(c, 500, "Failed to get API key usage stats")
return return

View File

@@ -3,9 +3,9 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -3,8 +3,8 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -4,8 +4,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -6,8 +6,8 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,9 +1,9 @@
package admin package admin
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -60,6 +60,7 @@ type UpdateSettingsRequest struct {
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"` ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
@@ -104,6 +105,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
SiteSubtitle: req.SiteSubtitle, SiteSubtitle: req.SiteSubtitle,
ApiBaseUrl: req.ApiBaseUrl, ApiBaseUrl: req.ApiBaseUrl,
ContactInfo: req.ContactInfo, ContactInfo: req.ContactInfo,
DocUrl: req.DocUrl,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
} }

View File

@@ -3,10 +3,10 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -4,9 +4,9 @@ import (
"net/http" "net/http"
"time" "time"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/pkg/sysutil" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -4,35 +4,32 @@ import (
"strconv" "strconv"
"time" "time"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// UsageHandler handles admin usage-related requests // UsageHandler handles admin usage-related requests
type UsageHandler struct { type UsageHandler struct {
usageRepo *repository.UsageLogRepository usageService *service.UsageService
apiKeyRepo *repository.ApiKeyRepository apiKeyService *service.ApiKeyService
usageService *service.UsageService adminService service.AdminService
adminService service.AdminService
} }
// NewUsageHandler creates a new admin usage handler // NewUsageHandler creates a new admin usage handler
func NewUsageHandler( func NewUsageHandler(
usageRepo *repository.UsageLogRepository,
apiKeyRepo *repository.ApiKeyRepository,
usageService *service.UsageService, usageService *service.UsageService,
apiKeyService *service.ApiKeyService,
adminService service.AdminService, adminService service.AdminService,
) *UsageHandler { ) *UsageHandler {
return &UsageHandler{ return &UsageHandler{
usageRepo: usageRepo, usageService: usageService,
apiKeyRepo: apiKeyRepo, apiKeyService: apiKeyService,
usageService: usageService, adminService: adminService,
adminService: adminService,
} }
} }
@@ -84,14 +81,14 @@ func (h *UsageHandler) List(c *gin.Context) {
} }
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{ filters := usagestats.UsageLogFilters{
UserID: userID, UserID: userID,
ApiKeyID: apiKeyID, ApiKeyID: apiKeyID,
StartTime: startTime, StartTime: startTime,
EndTime: endTime, EndTime: endTime,
} }
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters) records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error()) response.InternalError(c, "Failed to list usage records: "+err.Error())
return return
@@ -179,7 +176,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
} }
// Get global stats // Get global stats
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime) stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return return
@@ -237,7 +234,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userID = id userID = id
} }
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30) keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil { if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error()) response.InternalError(c, "Failed to search API keys: "+err.Error())
return return

View File

@@ -3,8 +3,8 @@ package admin
import ( import (
"strconv" "strconv"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -49,8 +49,9 @@ type UpdateUserRequest struct {
// UpdateBalanceRequest represents balance update request // UpdateBalanceRequest represents balance update request
type UpdateBalanceRequest struct { type UpdateBalanceRequest struct {
Balance float64 `json:"balance" binding:"required"` Balance float64 `json:"balance" binding:"required,gt=0"`
Operation string `json:"operation" binding:"required,oneof=set add subtract"` Operation string `json:"operation" binding:"required,oneof=set add subtract"`
Notes string `json:"notes"`
} }
// List handles listing all users with pagination // List handles listing all users with pagination
@@ -183,7 +184,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return return
} }
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation) user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil { if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error()) response.InternalError(c, "Failed to update balance: "+err.Error())
return return

View File

@@ -3,10 +3,10 @@ package handler
import ( import (
"strconv" "strconv"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,9 +1,9 @@
package handler package handler
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -10,11 +10,11 @@ import (
"strings" "strings"
"time" "time"
"sub2api/internal/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -6,8 +6,8 @@ import (
"net/http" "net/http"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,7 +1,7 @@
package handler package handler
import ( import (
"sub2api/internal/handler/admin" "github.com/Wei-Shaw/sub2api/internal/handler/admin"
) )
// AdminHandlers contains all admin-related HTTP handlers // AdminHandlers contains all admin-related HTTP handlers

View File

@@ -9,9 +9,9 @@ import (
"net/http" "net/http"
"time" "time"
"sub2api/internal/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
"sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,9 +1,9 @@
package handler package handler
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,8 +1,8 @@
package handler package handler
import ( import (
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,9 +1,9 @@
package handler package handler
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -4,12 +4,11 @@ import (
"strconv" "strconv"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/service"
"sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -17,15 +16,13 @@ import (
// UsageHandler handles usage-related requests // UsageHandler handles usage-related requests
type UsageHandler struct { type UsageHandler struct {
usageService *service.UsageService usageService *service.UsageService
usageRepo *repository.UsageLogRepository
apiKeyService *service.ApiKeyService apiKeyService *service.ApiKeyService
} }
// NewUsageHandler creates a new UsageHandler // NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler { func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
return &UsageHandler{ return &UsageHandler{
usageService: usageService, usageService: usageService,
usageRepo: usageRepo,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
} }
} }
@@ -260,7 +257,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
return return
} }
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID) stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get dashboard statistics") response.InternalError(c, "Failed to get dashboard statistics")
return return
@@ -287,7 +284,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day") granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage trend") response.InternalError(c, "Failed to get usage trend")
return return
@@ -318,7 +315,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get model statistics") response.InternalError(c, "Failed to get model statistics")
return return
@@ -387,7 +384,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
return return
} }
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs) stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get API key usage stats") response.InternalError(c, "Failed to get API key usage stats")
return return

View File

@@ -1,9 +1,9 @@
package handler package handler
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,8 +1,8 @@
package handler package handler
import ( import (
"sub2api/internal/handler/admin" "github.com/Wei-Shaw/sub2api/internal/handler/admin"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire" "github.com/google/wire"
) )

View File

@@ -1,9 +1,9 @@
package infrastructure package infrastructure
import ( import (
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"

View File

@@ -1,7 +1,7 @@
package infrastructure package infrastructure
import ( import (
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )

View File

@@ -1,7 +1,7 @@
package infrastructure package infrastructure
import ( import (
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"

View File

@@ -3,9 +3,9 @@ package middleware
import ( import (
"context" "context"
"crypto/subtle" "crypto/subtle"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings" "strings"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -1,7 +1,7 @@
package middleware package middleware
import ( import (
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -3,9 +3,9 @@ package middleware
import ( import (
"context" "context"
"errors" "errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"log" "log"
"strings" "strings"
"sub2api/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"

View File

@@ -2,9 +2,9 @@ package middleware
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings" "strings"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -14,6 +14,7 @@ type RedeemCode struct {
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"` UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"` UsedAt *time.Time `json:"used_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"` CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段 // 订阅类型专用字段

View File

@@ -42,6 +42,7 @@ const (
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入 SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式 SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
@@ -80,6 +81,7 @@ type SystemSettings struct {
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"` ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
@@ -97,5 +99,6 @@ type PublicSettings struct {
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"` ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"` Version string `json:"version"`
} }

View File

@@ -0,0 +1,209 @@
package usagestats
import "time"
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
StartTime *time.Time
EndTime *time.Time
}
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"`
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
HighestCostDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
} `json:"highest_request_day"`
}
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
}

View File

@@ -2,11 +2,14 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "errors"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type AccountRepository struct { type AccountRepository struct {
@@ -23,14 +26,34 @@ func (r *AccountRepository) Create(ctx context.Context, account *model.Account)
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups").First(&account, id).Error err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 填充 GroupIDs 虚拟字段 // 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups)) account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups { for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID) account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
return &account, nil
}
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
if crsAccountID == "" {
return nil, nil
}
var account model.Account
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
} }
return &account, nil return &account, nil
} }
@@ -303,3 +326,75 @@ func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error Update("schedulable", schedulable).Error
} }
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
}
// Get current account to preserve existing Extra data
var account model.Account
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err
}
// Initialize Extra if nil
if account.Extra == nil {
account.Extra = make(model.JSONB)
}
// Merge updates into existing Extra
for k, v := range updates {
account.Extra[k] = v
}
// Save updated Extra
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("extra", account.Extra).Error
}
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
updateMap := map[string]any{}
if updates.Name != nil {
updateMap["name"] = *updates.Name
}
if updates.ProxyID != nil {
updateMap["proxy_id"] = updates.ProxyID
}
if updates.Concurrency != nil {
updateMap["concurrency"] = *updates.Concurrency
}
if updates.Priority != nil {
updateMap["priority"] = *updates.Priority
}
if updates.Status != nil {
updateMap["status"] = *updates.Status
}
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
}
if len(updateMap) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).
Model(&model.Account{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
return result.RowsAffected, result.Error
}

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )

View File

@@ -2,8 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -8,7 +8,7 @@ import (
"strconv" "strconv"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )

View File

@@ -7,10 +7,11 @@ import (
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"time" "time"
"sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )
@@ -139,20 +140,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) { func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL) client := createReqClient(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code authCode := code
codeState := "" codeState := ""
if len(code) > 0 { if idx := strings.Index(code, "#"); idx != -1 {
parts := make([]string, 0, 2) authCode = code[:idx]
for i, part := range []rune(code) { codeState = code[idx+1:]
if part == '#' {
authCode = code[:i]
codeState = code[i+1:]
break
}
}
if len(parts) == 0 {
authCode = code
}
} }
reqBody := map[string]any{ reqBody := map[string]any{

View File

@@ -9,7 +9,7 @@ import (
"net/url" "net/url"
"time" "time"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
type claudeUsageService struct{} type claudeUsageService struct{}

View File

@@ -5,60 +5,95 @@ import (
"fmt" "fmt"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
const ( const (
accountConcurrencyKeyPrefix = "concurrency:account:" // Key prefixes for independent slot keys
userConcurrencyKeyPrefix = "concurrency:user:" // Format: concurrency:account:{accountID}:{requestID}
waitQueueKeyPrefix = "concurrency:wait:" accountSlotKeyPrefix = "concurrency:account:"
concurrencyTTL = 5 * time.Minute // Format: concurrency:user:{userID}:{requestID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Slot TTL - each slot expires independently
slotTTL = 5 * time.Minute
) )
var ( var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds
acquireScript = redis.NewScript(` acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local pattern = KEYS[1]
if current == false then local slotKey = KEYS[2]
current = 0 local maxConcurrency = tonumber(ARGV[1])
else local ttl = tonumber(ARGV[2])
current = tonumber(current)
end -- Count existing slots using SCAN
if current < tonumber(ARGV[1]) then local cursor = "0"
redis.call('INCR', KEYS[1]) local count = 0
redis.call('EXPIRE', KEYS[1], ARGV[2]) repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- Check if we can acquire a slot
if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl)
return 1 return 1
end end
return 0 return 0
`) `)
releaseScript = redis.NewScript(` // getCountScript counts slots using SCAN
local current = redis.call('GET', KEYS[1]) // KEYS[1] = pattern for SCAN
if current ~= false and tonumber(current) > 0 then getCountScript = redis.NewScript(`
redis.call('DECR', KEYS[1]) local pattern = KEYS[1]
end local cursor = "0"
return 1 local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
return count
`) `)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(` incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1] local current = redis.call('GET', KEYS[1])
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then if current == false then
current = 0 current = 0
else else
current = tonumber(current) current = tonumber(current)
end end
if current >= maxWait then
if current >= tonumber(ARGV[1]) then
return 0 return 0
end end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl) local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1 return 1
`) `)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(` decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then if current ~= false and tonumber(current) > 0 then
@@ -76,49 +111,86 @@ func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
return &concurrencyCache{rdb: rdb} return &concurrencyCache{rdb: rdb}
} }
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) { // Helper functions for key generation
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) func accountSlotKey(accountID int64, requestID string) string {
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int() return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
}
func accountSlotPattern(accountID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID)
slotKey := accountSlotKey(accountID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
return result == 1, nil return result == 1, nil
} }
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error { func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) slotKey := accountSlotKey(accountID, requestID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result() return c.rdb.Del(ctx, slotKey).Err()
return err
} }
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID) pattern := accountSlotPattern(accountID)
return c.rdb.Get(ctx, key).Int() result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
} }
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) { // User slot operations
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int() func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID)
slotKey := userSlotKey(userID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
return result == 1, nil return result == 1, nil
} }
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error { func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) slotKey := userSlotKey(userID, requestID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result() return c.rdb.Del(ctx, slotKey).Err()
return err
} }
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID) pattern := userSlotPattern(userID)
return c.rdb.Get(ctx, key).Int() result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
if err != nil {
return 0, err
}
return result, nil
} }
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int() result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -126,7 +198,7 @@ func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64,
} }
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err return err
} }

View File

@@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )

View File

@@ -9,7 +9,7 @@ import (
"os" "os"
"time" "time"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
type githubReleaseClient struct { type githubReleaseClient struct {

View File

@@ -2,8 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -5,8 +5,8 @@ import (
"net/url" "net/url"
"time" "time"
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// httpUpstreamService is a generic HTTP upstream service that can be used for // httpUpstreamService is a generic HTTP upstream service that can be used for

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )

View File

@@ -6,8 +6,8 @@ import (
"net/url" "net/url"
"time" "time"
"sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )

View File

@@ -8,7 +8,7 @@ import (
"strings" "strings"
"time" "time"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
type pricingRemoteClient struct { type pricingRemoteClient struct {

View File

@@ -11,7 +11,7 @@ import (
"net/url" "net/url"
"time" "time"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
) )

View File

@@ -2,8 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )

View File

@@ -2,8 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"

View File

@@ -2,7 +2,7 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"

View File

@@ -9,7 +9,7 @@ import (
"strings" "strings"
"time" "time"
"sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify" const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )

View File

@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
@@ -19,6 +19,30 @@ func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository {
return &UsageLogRepository{db: db} return &UsageLogRepository{db: db}
} }
// getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
var perfStats struct {
RequestCount int64 `gorm:"column:request_count"`
TokenCount int64 `gorm:"column:token_count"`
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
`).
Where("created_at >= ?", fiveMinutesAgo)
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
db.Scan(&perfStats)
// 返回5分钟平均值
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
}
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error return r.db.WithContext(ctx).Create(log).Error
} }
@@ -113,46 +137,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
} }
// DashboardStats 仪表盘统计 // DashboardStats 仪表盘统计
type DashboardStats struct { type DashboardStats = usagestats.DashboardStats
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
}
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats var stats DashboardStats
@@ -269,6 +254,9 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
stats.TodayCost = todayStats.TodayCost stats.TodayCost = todayStats.TodayCost
stats.TodayActualCost = todayStats.TodayActualCost stats.TodayActualCost = todayStats.TodayActualCost
// 性能指标RPM 和 TPM最近1分钟全局
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, 0)
return &stats, nil return &stats, nil
} }
@@ -398,47 +386,16 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
} }
// TrendDataPoint represents a single point in trend data // TrendDataPoint represents a single point in trend data
type TrendDataPoint struct { type TrendDataPoint = usagestats.TrendDataPoint
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model // ModelStat represents usage statistics for a single model
type ModelStat struct { type ModelStat = usagestats.ModelStat
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point // UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct { type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point // ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct { type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date // GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
@@ -531,34 +488,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
} }
// UserDashboardStats 用户仪表盘统计 // UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct { type UserDashboardStats = usagestats.UserDashboardStats
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
}
// GetUserDashboardStats 获取用户专属的仪表盘统计 // GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
@@ -641,6 +571,9 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
stats.TodayCost = todayStats.TodayCost stats.TodayCost = todayStats.TodayCost
stats.TodayActualCost = todayStats.TodayActualCost stats.TodayActualCost = todayStats.TodayActualCost
// 性能指标RPM 和 TPM最近1分钟仅统计该用户的请求
stats.Rpm, stats.Tpm = r.getPerformanceStats(ctx, userID)
return &stats, nil return &stats, nil
} }
@@ -705,12 +638,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
} }
// UsageLogFilters represents filters for usage log queries // UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct { type UsageLogFilters = usagestats.UsageLogFilters
UserID int64
ApiKeyID int64
StartTime *time.Time
EndTime *time.Time
}
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
@@ -758,23 +686,10 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params paginat
} }
// UsageStats represents usage statistics // UsageStats represents usage statistics
type UsageStats struct { type UsageStats = usagestats.UsageStats
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user // BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct { type BatchUserUsageStats = usagestats.BatchUserUsageStats
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// GetBatchUserUsageStats gets today and total actual_cost for multiple users // GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
@@ -834,11 +749,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
} }
// BatchApiKeyUsageStats represents usage stats for a single API key // BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct { type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys // GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
@@ -1012,53 +923,13 @@ func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
} }
// AccountUsageHistory represents daily usage history for an account // AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory struct { type AccountUsageHistory = usagestats.AccountUsageHistory
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
}
// AccountUsageSummary represents summary statistics for an account // AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct { type AccountUsageSummary = usagestats.AccountUsageSummary
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"`
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
HighestCostDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
} `json:"highest_request_day"`
}
// AccountUsageStatsResponse represents the full usage statistics response for an account // AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct { type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
}
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) { func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {

View File

@@ -2,8 +2,8 @@ package repository
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -4,8 +4,8 @@ import (
"context" "context"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -1,7 +1,7 @@
package repository package repository
import ( import (
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/google/wire" "github.com/google/wire"
) )

View File

@@ -1,11 +1,11 @@
package server package server
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"net/http" "net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/repository"
"sub2api/internal/service"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"

View File

@@ -1,13 +1,13 @@
package server package server
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/web"
"net/http" "net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/web"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -180,6 +180,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
accounts.GET("", h.Admin.Account.List) accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID) accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create) accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update) accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete) accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test) accounts.POST("/:id/test", h.Admin.Account.Test)
@@ -192,6 +193,8 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes // Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL) accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)

View File

@@ -4,9 +4,9 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -14,10 +14,10 @@ import (
"strings" "strings"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
@@ -388,12 +388,14 @@ func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
"stream": true, "stream": true,
} }
// OAuth accounts using ChatGPT internal API require store: false and instructions // OAuth accounts using ChatGPT internal API require store: false
if isOAuth { if isOAuth {
payload["store"] = false payload["store"] = false
payload["instructions"] = openai.DefaultInstructions
} }
// All accounts require instructions for Responses API
payload["instructions"] = openai.DefaultInstructions
return payload return payload
} }

View File

@@ -7,8 +7,9 @@ import (
"sync" "sync"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// usageCache 用于缓存usage数据 // usageCache 用于缓存usage数据
@@ -176,6 +177,14 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}, nil }, nil
} }
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get account usage stats failed: %w", err)
}
return stats, nil
}
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 // fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) { func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")

View File

@@ -7,9 +7,9 @@ import (
"log" "log"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -22,7 +22,7 @@ type AdminService interface {
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
@@ -45,6 +45,7 @@ type AdminService interface {
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error) ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
@@ -140,6 +141,33 @@ type UpdateAccountInput struct {
GroupIDs *[]int64 GroupIDs *[]int64
} }
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
Status string
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
}
// BulkUpdateAccountResult captures the result for a single account update.
type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []BulkUpdateAccountResult `json:"results"`
}
type CreateProxyInput struct { type CreateProxyInput struct {
Name string Name string
Protocol string Protocol string
@@ -271,8 +299,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
return nil, errors.New("cannot disable admin user") return nil, errors.New("cannot disable admin user")
} }
// Track balance and concurrency changes for logging
oldBalance := user.Balance
oldConcurrency := user.Concurrency oldConcurrency := user.Concurrency
if input.Email != "" { if input.Email != "" {
@@ -284,7 +310,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
} }
} }
// 更新用户字段
if input.Username != nil { if input.Username != nil {
user.Username = *input.Username user.Username = *input.Username
} }
@@ -295,22 +320,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.Notes = *input.Notes user.Notes = *input.Notes
} }
// Role is not allowed to be changed via API to prevent privilege escalation
if input.Status != "" { if input.Status != "" {
user.Status = input.Status user.Status = input.Status
} }
// 只在指针非 nil 时更新 Balance支持设置为 0
if input.Balance != nil {
user.Balance = *input.Balance
}
// 只在指针非 nil 时更新 Concurrency支持设置为任意值
if input.Concurrency != nil { if input.Concurrency != nil {
user.Concurrency = *input.Concurrency user.Concurrency = *input.Concurrency
} }
// 只在指针非 nil 时更新 AllowedGroups
if input.AllowedGroups != nil { if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups user.AllowedGroups = *input.AllowedGroups
} }
@@ -319,41 +336,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
return nil, err return nil, err
} }
// 余额变化时失效缓存
if input.Balance != nil && *input.Balance != oldBalance {
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, id); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", id, err)
}
}()
}
}
// Create adjustment records for balance/concurrency changes
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := model.GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
Code: code,
Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
log.Printf("failed to create balance adjustment redeem code: %v", err)
}
}
concurrencyDiff := user.Concurrency - oldConcurrency concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 { if concurrencyDiff != 0 {
code, err := model.GenerateRedeemCode() code, err := model.GenerateRedeemCode()
@@ -390,12 +372,14 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return s.userRepo.Delete(ctx, id) return s.userRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) { func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
oldBalance := user.Balance
switch operation { switch operation {
case "set": case "set":
user.Balance = balance user.Balance = balance
@@ -405,11 +389,14 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
user.Balance -= balance user.Balance -= balance
} }
if user.Balance < 0 {
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
}
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err return nil, err
} }
// 失效余额缓存
if s.billingCacheService != nil { if s.billingCacheService != nil {
go func() { go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -420,6 +407,30 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
}() }()
} }
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := model.GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
Code: code,
Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
UsedBy: &user.ID,
Notes: notes,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
log.Printf("failed to create balance adjustment redeem code: %v", err)
}
}
return user, nil return user, nil
} }
@@ -711,6 +722,65 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return account, nil return account, nil
} }
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
result := &BulkUpdateAccountsResult{
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
}
if len(input.AccountIDs) == 0 {
return result, nil
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := ports.AccountBulkUpdate{
Credentials: input.Credentials,
Extra: input.Extra,
}
if input.Name != "" {
repoUpdates.Name = &input.Name
}
if input.ProxyID != nil {
repoUpdates.ProxyID = input.ProxyID
}
if input.Concurrency != nil {
repoUpdates.Concurrency = input.Concurrency
}
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
// Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
return nil, err
}
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
entry.Success = true
result.Success++
result.Results = append(result.Results, entry)
}
return result, nil
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id) return s.accountRepo.Delete(ctx, id)
} }

View File

@@ -6,11 +6,11 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@@ -455,3 +455,11 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
// 标准类型分组:使用原有逻辑 // 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive) return user.CanBindGroup(group.ID, group.IsExclusive)
} }
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}
return keys, nil
}

View File

@@ -4,10 +4,10 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"log" "log"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/service/ports"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"

View File

@@ -7,8 +7,8 @@ import (
"log" "log"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// 错误定义 // 错误定义

View File

@@ -2,9 +2,9 @@ package service
import ( import (
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/config"
"log" "log"
"strings" "strings"
"sub2api/internal/config"
) )
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致 // ModelPricing 模型价格配置per-token价格与LiteLLM格式一致

View File

@@ -2,12 +2,26 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt"
"log" "log"
"time" "time"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// generateRequestID generates a unique request ID for concurrency slot tracking
// Uses 8 random bytes (16 hex chars) for uniqueness
func generateRequestID() string {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
// Fallback to nanosecond timestamp (extremely rare case)
return fmt.Sprintf("%x", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}
const ( const (
// Default extra wait slots beyond concurrency limit // Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20 defaultExtraWaitSlots = 20
@@ -41,7 +55,10 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
}, nil }, nil
} }
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency) // Generate unique request ID for this slot
requestID := generateRequestID()
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -52,8 +69,8 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
ReleaseFunc: func() { ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil { if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err) log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
} }
}, },
}, nil }, nil
@@ -77,7 +94,10 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
}, nil }, nil
} }
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency) // Generate unique request ID for this slot
requestID := generateRequestID()
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -88,8 +108,8 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
ReleaseFunc: func() { ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil { if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
log.Printf("Warning: failed to release user slot for %d: %v", userID, err) log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
} }
}, },
}, nil }, nil
@@ -147,3 +167,20 @@ func CalculateMaxWait(userConcurrency int) int {
} }
return userConcurrency + defaultExtraWaitSlots return userConcurrency + defaultExtraWaitSlots
} }
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int)
for _, accountID := range accountIDs {
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
if err != nil {
// If key doesn't exist in Redis, count is 0
count = 0
}
result[accountID] = count
}
return result, nil
}

View File

@@ -0,0 +1,961 @@
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
type CRSSyncService struct {
accountRepo ports.AccountRepository
proxyRepo ports.ProxyRepository
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
}
func NewCRSSyncService(
accountRepo ports.AccountRepository,
proxyRepo ports.ProxyRepository,
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
proxyRepo: proxyRepo,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
}
}
type SyncFromCRSInput struct {
BaseURL string
Username string
Password string
SyncProxies bool
}
type SyncFromCRSItemResult struct {
CRSAccountID string `json:"crs_account_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Action string `json:"action"` // created/updated/failed/skipped
Error string `json:"error,omitempty"`
}
type SyncFromCRSResult struct {
Created int `json:"created"`
Updated int `json:"updated"`
Skipped int `json:"skipped"`
Failed int `json:"failed"`
Items []SyncFromCRSItemResult `json:"items"`
}
type crsLoginResponse struct {
Success bool `json:"success"`
Token string `json:"token"`
Message string `json:"message"`
Error string `json:"error"`
Username string `json:"username"`
}
type crsExportResponse struct {
Success bool `json:"success"`
Error string `json:"error"`
Message string `json:"message"`
Data struct {
ExportedAt string `json:"exportedAt"`
ClaudeAccounts []crsClaudeAccount `json:"claudeAccounts"`
ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"`
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
} `json:"data"`
}
type crsProxy struct {
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
}
type crsClaudeAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
AuthType string `json:"authType"` // oauth/setup-token
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
type crsConsoleAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
MaxConcurrentTasks int `json:"maxConcurrentTasks"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
}
type crsOpenAIResponsesAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
}
type crsOpenAIOAuthAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
AuthType string `json:"authType"` // oauth
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL)
if err != nil {
return nil, err
}
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required")
}
client := &http.Client{Timeout: 20 * time.Second}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
if err != nil {
return nil, err
}
now := time.Now().UTC().Format(time.RFC3339)
result := &SyncFromCRSResult{
Items: make(
[]SyncFromCRSItemResult,
0,
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts),
),
}
var proxies []model.Proxy
if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx)
}
// Claude OAuth / Setup Token -> sub2api anthropic oauth/setup-token
for _, src := range exported.Data.ClaudeAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
targetType := strings.TrimSpace(src.AuthType)
if targetType == "" {
targetType = "oauth"
}
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken {
item.Action = "skipped"
item.Error = "unsupported authType: " + targetType
result.Skipped++
result.Items = append(result.Items, item)
continue
}
accessToken, _ := src.Credentials["access_token"].(string)
if strings.TrimSpace(accessToken) == "" {
item.Action = "failed"
item.Error = "missing access_token"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
// 🔧 Remove /v1 suffix from base_url for Claude accounts
cleanBaseURL(credentials, "/v1")
// 🔧 Convert expires_at from ISO string to Unix timestamp
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
credentials["expires_at"] = t.Unix()
}
}
// 🔧 Add intercept_warmup_requests if not present (defaults to false)
if _, exists := credentials["intercept_warmup_requests"]; !exists {
credentials["intercept_warmup_requests"] = false
}
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
// 🔧 Preserve all CRS extra fields and add sync metadata
extra := make(map[string]any)
if src.Extra != nil {
for k, v := range src.Extra {
extra[k] = v
}
}
extra["crs_account_id"] = src.ID
extra["crs_kind"] = src.Kind
extra["crs_synced_at"] = now
// Extract org_uuid and account_uuid from CRS credentials to extra
if orgUUID, ok := src.Credentials["org_uuid"]; ok {
extra["org_uuid"] = orgUUID
}
if accountUUID, ok := src.Credentials["account_uuid"]; ok {
extra["account_uuid"] = accountUUID
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Type: targetType,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after creation
if targetType == model.AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
}
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
// Update existing
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Type = targetType
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after update
if targetType == model.AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
}
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// Claude Console API Key -> sub2api anthropic apikey
for _, src := range exported.Data.ClaudeConsoleAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
apiKey, _ := src.Credentials["api_key"].(string)
if strings.TrimSpace(apiKey) == "" {
item.Action = "failed"
item.Error = "missing api_key"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
priority := clampPriority(src.Priority)
concurrency := 3
if src.MaxConcurrentTasks > 0 {
concurrency = src.MaxConcurrentTasks
}
status := mapCRSStatus(src.IsActive, src.Status)
extra := map[string]any{
"crs_account_id": src.ID,
"crs_kind": src.Kind,
"crs_synced_at": now,
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// OpenAI OAuth -> sub2api openai oauth
for _, src := range exported.Data.OpenAIOAuthAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
accessToken, _ := src.Credentials["access_token"].(string)
if strings.TrimSpace(accessToken) == "" {
item.Action = "failed"
item.Error = "missing access_token"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(
ctx,
input.SyncProxies,
&proxies,
src.Proxy,
fmt.Sprintf("crs-%s", src.Name),
)
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
// Normalize token_type
if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" {
credentials["token_type"] = "Bearer"
}
// 🔧 Convert expires_at from ISO string to Unix timestamp
if expiresAtStr, ok := credentials["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
credentials["expires_at"] = t.Unix()
}
}
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
// 🔧 Preserve all CRS extra fields and add sync metadata
extra := make(map[string]any)
if src.Extra != nil {
for k, v := range src.Extra {
extra[k] = v
}
}
extra["crs_account_id"] = src.ID
extra["crs_kind"] = src.Kind
extra["crs_synced_at"] = now
// Extract email from CRS extra (crs_email -> email)
if crsEmail, ok := src.Extra["crs_email"]; ok {
extra["email"] = crsEmail
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeOAuth,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after creation
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
// 🔄 Refresh OAuth token after update
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// OpenAI Responses API Key -> sub2api openai apikey
for _, src := range exported.Data.OpenAIResponsesAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
apiKey, _ := src.Credentials["api_key"].(string)
if strings.TrimSpace(apiKey) == "" {
item.Action = "failed"
item.Error = "missing api_key"
result.Failed++
result.Items = append(result.Items, item)
continue
}
if baseURL, ok := src.Credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" {
src.Credentials["base_url"] = "https://api.openai.com"
}
// 🔧 Remove /v1 suffix from base_url for OpenAI accounts
cleanBaseURL(src.Credentials, "/v1")
proxyID, err := s.mapOrCreateProxy(
ctx,
input.SyncProxies,
&proxies,
src.Proxy,
fmt.Sprintf("crs-%s", src.Name),
)
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
priority := clampPriority(src.Priority)
concurrency := 3
status := mapCRSStatus(src.IsActive, src.Status)
extra := map[string]any{
"crs_account_id": src.ID,
"crs_kind": src.Kind,
"crs_synced_at": now,
}
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
Status: status,
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = concurrency
existing.Priority = priority
existing.Status = status
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
return result, nil
}
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates.
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
out := make(model.JSONB)
for k, v := range existing {
out[k] = v
}
for k, v := range updates {
out[k] = v
}
return out
}
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) {
if !enabled || src == nil {
return nil, nil
}
protocol := strings.ToLower(strings.TrimSpace(src.Protocol))
switch protocol {
case "socks":
protocol = "socks5"
case "socks5h":
protocol = "socks5"
}
host := strings.TrimSpace(src.Host)
port := src.Port
username := strings.TrimSpace(src.Username)
password := strings.TrimSpace(src.Password)
if protocol == "" || host == "" || port <= 0 {
return nil, nil
}
if protocol != "http" && protocol != "https" && protocol != "socks5" {
return nil, nil
}
// Find existing proxy (active only).
for _, p := range *cached {
if strings.EqualFold(p.Protocol, protocol) &&
p.Host == host &&
p.Port == port &&
p.Username == username &&
p.Password == password {
id := p.ID
return &id, nil
}
}
// Create new proxy
proxy := &model.Proxy{
Name: defaultProxyName(defaultName, protocol, host, port),
Protocol: protocol,
Host: host,
Port: port,
Username: username,
Password: password,
Status: model.StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
*cached = append(*cached, *proxy)
id := proxy.ID
return &id, nil
}
func defaultProxyName(base, protocol, host string, port int) string {
base = strings.TrimSpace(base)
if base == "" {
base = "crs"
}
return fmt.Sprintf("%s (%s://%s:%d)", base, protocol, host, port)
}
func defaultName(name, id string) string {
if strings.TrimSpace(name) != "" {
return strings.TrimSpace(name)
}
return "CRS " + id
}
func clampPriority(priority int) int {
if priority < 1 || priority > 100 {
return 50
}
return priority
}
func sanitizeCredentialsMap(input map[string]any) map[string]any {
if input == nil {
return map[string]any{}
}
out := make(map[string]any, len(input))
for k, v := range input {
// Avoid nil values to keep JSONB cleaner
if v != nil {
out[k] = v
}
}
return out
}
func mapCRSStatus(isActive bool, status string) string {
if !isActive {
return "inactive"
}
if strings.EqualFold(strings.TrimSpace(status), "error") {
return "error"
}
return "active"
}
func normalizeBaseURL(raw string) (string, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", errors.New("base_url is required")
}
u, err := url.Parse(trimmed)
if err != nil || u.Scheme == "" || u.Host == "" {
return "", fmt.Errorf("invalid base_url: %s", trimmed)
}
u.Path = strings.TrimRight(u.Path, "/")
return strings.TrimRight(u.String(), "/"), nil
}
// cleanBaseURL removes trailing suffix from base_url in credentials
// Used for both Claude and OpenAI accounts to remove /v1
func cleanBaseURL(credentials map[string]any, suffixToRemove string) {
if baseURL, ok := credentials["base_url"].(string); ok && baseURL != "" {
trimmed := strings.TrimSpace(baseURL)
if strings.HasSuffix(trimmed, suffixToRemove) {
credentials["base_url"] = strings.TrimSuffix(trimmed, suffixToRemove)
}
}
}
func crsLogin(ctx context.Context, client *http.Client, baseURL, username, password string) (string, error) {
payload := map[string]any{
"username": username,
"password": password,
}
body, _ := json.Marshal(payload)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/web/auth/login", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("crs login failed: status=%d body=%s", resp.StatusCode, string(raw))
}
var parsed crsLoginResponse
if err := json.Unmarshal(raw, &parsed); err != nil {
return "", fmt.Errorf("crs login parse failed: %w", err)
}
if !parsed.Success || strings.TrimSpace(parsed.Token) == "" {
msg := parsed.Message
if msg == "" {
msg = parsed.Error
}
if msg == "" {
msg = "unknown error"
}
return "", errors.New("crs login failed: " + msg)
}
return parsed.Token, nil
}
func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminToken string) (*crsExportResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/admin/sync/export-accounts?include_secrets=true", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+adminToken)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 5<<20))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("crs export failed: status=%d body=%s", resp.StatusCode, string(raw))
}
var parsed crsExportResponse
if err := json.Unmarshal(raw, &parsed); err != nil {
return nil, fmt.Errorf("crs export parse failed: %w", err)
}
if !parsed.Success {
msg := parsed.Message
if msg == "" {
msg = parsed.Error
}
if msg == "" {
msg = "unknown error"
}
return nil, errors.New("crs export failed: " + msg)
}
return &parsed, nil
}
// refreshOAuthToken attempts to refresh OAuth token for a synced account
// Returns updated credentials or nil if refresh failed/not applicable
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB {
if account.Type != model.AccountTypeOAuth {
return nil
}
var newCredentials map[string]any
var err error
switch account.Platform {
case model.PlatformAnthropic:
if s.oauthService == nil {
return nil
}
tokenInfo, refreshErr := s.oauthService.RefreshAccountToken(ctx, account)
if refreshErr != nil {
err = refreshErr
} else {
// Preserve existing credentials
newCredentials = make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
// Update token fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
}
case model.PlatformOpenAI:
if s.openaiOAuthService == nil {
return nil
}
tokenInfo, refreshErr := s.openaiOAuthService.RefreshAccountToken(ctx, account)
if refreshErr != nil {
err = refreshErr
} else {
newCredentials = s.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
}
default:
return nil
}
if err != nil {
// Log but don't fail the sync - token might still be valid or refreshable later
return nil
}
return model.JSONB(newCredentials)
}

View File

@@ -0,0 +1,77 @@
package service
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
// DashboardService provides aggregated statistics for admin dashboard.
type DashboardService struct {
usageRepo ports.UsageLogRepository
}
func NewDashboardService(usageRepo ports.UsageLogRepository) *DashboardService {
return &DashboardService{
usageRepo: usageRepo,
}
}
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
stats, err := s.usageRepo.GetDashboardStats(ctx)
if err != nil {
return nil, fmt.Errorf("get dashboard stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
trend, err := s.usageRepo.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get user usage trend: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
if err != nil {
return nil, fmt.Errorf("get batch user usage stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}
return stats, nil
}

View File

@@ -6,11 +6,11 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"math/big" "math/big"
"net/smtp" "net/smtp"
"strconv" "strconv"
"sub2api/internal/model"
"sub2api/internal/service/ports"
"time" "time"
) )

View File

@@ -16,10 +16,12 @@ import (
"strings" "strings"
"time" "time"
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -358,6 +360,25 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
return accessToken, "oauth", nil return accessToken, "oauth", nil
} }
// 重试相关常量
const (
maxRetries = 3 // 最大重试次数
retryDelay = 2 * time.Second // 重试等待时间
)
// shouldRetryUpstreamError 判断是否应该重试上游错误
// OAuth/Setup Token 账号:仅 403 重试
// API Key 账号:未配置的错误码重试
func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试
if account.IsOAuth() {
return statusCode == 403
}
// API Key 账号:未配置的错误码重试
return !account.ShouldHandleErrorCode(statusCode)
}
// Forward 转发请求到Claude API // Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
@@ -371,6 +392,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
return nil, fmt.Errorf("parse request: %w", err) return nil, fmt.Errorf("parse request: %w", err)
} }
if !gjson.GetBytes(body, "system").Exists() {
body, _ = sjson.SetBytes(body, "system", []any{
map[string]any{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
"cache_control": map[string]string{
"type": "ephemeral",
},
},
})
}
// 应用模型映射仅对apikey类型账号 // 应用模型映射仅对apikey类型账号
originalModel := req.Model originalModel := req.Model
if account.Type == model.AccountTypeApiKey { if account.Type == model.AccountTypeApiKey {
@@ -389,26 +422,51 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
return nil, err return nil, err
} }
// 构建上游请求
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return nil, err
}
// 获取代理URL // 获取代理URL
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// 发送请求 // 重试循环
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL) var resp *http.Response
if err != nil { for attempt := 1; attempt <= maxRetries; attempt++ {
return nil, fmt.Errorf("upstream request failed: %w", err) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return nil, err
}
// 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err)
}
// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
_ = resp.Body.Close()
time.Sleep(retryDelay)
continue
}
// 最后一次尝试也失败,跳出循环处理重试耗尽
break
}
// 不需要重试(成功或不可重试的错误),跳出循环
break
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
// 处理错误响应包括401由后台TokenRefreshService维护token有效性 // 处理重试耗尽的情况
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }
@@ -570,19 +628,6 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,返回通用 500 错误(不做任何账号状态处理)
if !account.ShouldHandleErrorCode(resp.StatusCode) {
c.JSON(http.StatusInternalServerError, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream gateway error",
},
})
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
}
// 处理上游错误,标记账号状态 // 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
@@ -591,6 +636,9 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
var statusCode int var statusCode int
switch resp.StatusCode { switch resp.StatusCode {
case 400:
c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401: case 401:
statusCode = http.StatusBadGateway statusCode = http.StatusBadGateway
errType = "upstream_error" errType = "upstream_error"
@@ -629,6 +677,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
} }
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode
// OAuth/Setup Token 账号的 403标记账号异常
if account.IsOAuth() && statusCode == 403 {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
} else {
// API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
}
// 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed after retries",
},
})
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", statusCode)
}
// streamingResult 流式响应结果 // streamingResult 流式响应结果
type streamingResult struct { type streamingResult struct {
usage *ClaudeUsage usage *ClaudeUsage
@@ -734,7 +810,7 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
} }
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// 解析message_start获取input tokens // 解析message_start获取input tokens标准Claude API格式
var msgStart struct { var msgStart struct {
Type string `json:"type"` Type string `json:"type"`
Message struct { Message struct {
@@ -747,15 +823,30 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
} }
// 解析message_delta获取output tokens // 解析message_delta获取tokens兼容GLM等把所有usage放在delta中的API
var msgDelta struct { var msgDelta struct {
Type string `json:"type"` Type string `json:"type"`
Usage struct { Usage struct {
OutputTokens int `json:"output_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
} `json:"usage"` } `json:"usage"`
} }
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
// output_tokens 总是从 message_delta 获取
usage.OutputTokens = msgDelta.Usage.OutputTokens usage.OutputTokens = msgDelta.Usage.OutputTokens
// 如果 message_start 中没有值,则从 message_delta 获取兼容GLM等API
if usage.InputTokens == 0 {
usage.InputTokens = msgDelta.Usage.InputTokens
}
if usage.CacheCreationInputTokens == 0 {
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens == 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
} }
} }

View File

@@ -4,9 +4,9 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -7,11 +7,11 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"log" "log"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
"sub2api/internal/service/ports"
"time" "time"
) )

View File

@@ -6,9 +6,9 @@ import (
"log" "log"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows

View File

@@ -11,12 +11,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
"sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -38,6 +39,18 @@ var openaiAllowedHeaders = map[string]bool{
"session_id": true, "session_id": true,
} }
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type OpenAICodexUsageSnapshot struct {
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
PrimaryResetAfterSeconds *int `json:"primary_reset_after_seconds,omitempty"`
PrimaryWindowMinutes *int `json:"primary_window_minutes,omitempty"`
SecondaryUsedPercent *float64 `json:"secondary_used_percent,omitempty"`
SecondaryResetAfterSeconds *int `json:"secondary_reset_after_seconds,omitempty"`
SecondaryWindowMinutes *int `json:"secondary_window_minutes,omitempty"`
PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
// OpenAIUsage represents OpenAI API response usage // OpenAIUsage represents OpenAI API response usage
type OpenAIUsage struct { type OpenAIUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
@@ -284,6 +297,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
} }
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == model.AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
@@ -708,3 +728,109 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return nil return nil
} }
// extractCodexUsageHeaders extracts Codex usage limits from response headers
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
snapshot := &OpenAICodexUsageSnapshot{}
hasData := false
// Helper to parse float64 from header
parseFloat := func(key string) *float64 {
if v := headers.Get(key); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
return &f
}
}
return nil
}
// Helper to parse int from header
parseInt := func(key string) *int {
if v := headers.Get(key); v != "" {
if i, err := strconv.Atoi(v); err == nil {
return &i
}
}
return nil
}
// Primary (weekly) limits
if v := parseFloat("x-codex-primary-used-percent"); v != nil {
snapshot.PrimaryUsedPercent = v
hasData = true
}
if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil {
snapshot.PrimaryResetAfterSeconds = v
hasData = true
}
if v := parseInt("x-codex-primary-window-minutes"); v != nil {
snapshot.PrimaryWindowMinutes = v
hasData = true
}
// Secondary (5h) limits
if v := parseFloat("x-codex-secondary-used-percent"); v != nil {
snapshot.SecondaryUsedPercent = v
hasData = true
}
if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil {
snapshot.SecondaryResetAfterSeconds = v
hasData = true
}
if v := parseInt("x-codex-secondary-window-minutes"); v != nil {
snapshot.SecondaryWindowMinutes = v
hasData = true
}
// Overflow ratio
if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil {
snapshot.PrimaryOverSecondaryPercent = v
hasData = true
}
if !hasData {
return nil
}
snapshot.UpdatedAt = time.Now().Format(time.RFC3339)
return snapshot
}
// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field
func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) {
if snapshot == nil {
return
}
// Convert snapshot to map for merging into Extra
updates := make(map[string]any)
if snapshot.PrimaryUsedPercent != nil {
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
}
if snapshot.PrimaryResetAfterSeconds != nil {
updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
}
if snapshot.PrimaryWindowMinutes != nil {
updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes
}
if snapshot.SecondaryUsedPercent != nil {
updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent
}
if snapshot.SecondaryResetAfterSeconds != nil {
updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
}
if snapshot.SecondaryWindowMinutes != nil {
updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes
}
if snapshot.PrimaryOverSecondaryPercent != nil {
updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent
}
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
// Update account's Extra field asynchronously
go func() {
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}()
}

View File

@@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// OpenAIOAuthService handles OpenAI OAuth authentication flows // OpenAIOAuthService handles OpenAI OAuth authentication flows

View File

@@ -4,13 +4,16 @@ import (
"context" "context"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error Create(ctx context.Context, account *model.Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error) GetByID(ctx context.Context, id int64) (*model.Account, error)
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
Update(ctx context.Context, account *model.Account) error Update(ctx context.Context, account *model.Account) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
@@ -34,4 +37,18 @@ type AccountRepository interface {
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
ClearRateLimit(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
}
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
Name *string
ProxyID *int64
Concurrency *int
Priority *int
Status *string
Credentials map[string]any
Extra map[string]any
} }

View File

@@ -3,8 +3,8 @@ package ports
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
type ApiKeyRepository interface { type ApiKeyRepository interface {

View File

@@ -3,17 +3,21 @@ package ports
import "context" import "context"
// ConcurrencyCache defines cache operations for concurrency service // ConcurrencyCache defines cache operations for concurrency service
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
type ConcurrencyCache interface { type ConcurrencyCache interface {
// Slot management // Account slot management - each slot is a separate key with independent TTL
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) // Key format: concurrency:account:{accountID}:{requestID}
ReleaseAccountSlot(ctx context.Context, accountID int64) error AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) // User slot management - each slot is a separate key with independent TTL
ReleaseUserSlot(ctx context.Context, userID int64) error // Key format: concurrency:user:{userID}:{requestID}
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error) GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue // Wait queue - uses counter with TTL set only on creation
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
} }

View File

@@ -3,8 +3,8 @@ package ports
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -3,7 +3,7 @@ package ports
import ( import (
"context" "context"
"sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
// OpenAIOAuthClient interface for OpenAI OAuth operations // OpenAIOAuthClient interface for OpenAI OAuth operations

View File

@@ -3,8 +3,8 @@ package ports
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
type ProxyRepository interface { type ProxyRepository interface {

View File

@@ -3,8 +3,8 @@ package ports
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
type RedeemCodeRepository interface { type RedeemCodeRepository interface {

View File

@@ -3,7 +3,7 @@ package ports
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
) )
type SettingRepository interface { type SettingRepository interface {

View File

@@ -4,9 +4,9 @@ import (
"context" "context"
"time" "time"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
) )
type UsageLogRepository interface { type UsageLogRepository interface {
@@ -25,4 +25,25 @@ type UsageLogRepository interface {
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
// Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
// Account stats
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
} }

View File

@@ -3,8 +3,8 @@ package ports
import ( import (
"context" "context"
"sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
type UserRepository interface { type UserRepository interface {

Some files were not shown because too many files have changed in this diff Show More