Compare commits

..

41 Commits

Author SHA1 Message Date
shaw
f1325e9ae6 chore: 调整Turnstile设置跳转地址 2025-12-20 15:14:36 +08:00
shaw
587012396b feat: 支持创建管理员APIKEY 2025-12-20 15:11:43 +08:00
shaw
adebd941e1 fix: 修复Oauth账号自动刷新token失败的bug 2025-12-20 13:01:58 +08:00
Wesley Liddick
bb500b7b2a Merge pull request #9 from NepetaLemon/refactor/add-http-service-ports
refactor(backend): service http ports
2025-12-19 23:35:13 -05:00
Forest
cceada7dae refactor(backend): service http ports 2025-12-20 11:57:02 +08:00
shaw
5c2e7ae265 fix: 调整订阅计费时间窗口为每日0点
- 窗口激活/重置时使用当天零点而非精确时间
- 使用服务器本地时区计算零点(支持 UTC+8 等时区)
- 窗口重置时失效 Redis 缓存,避免数据不一致
2025-12-20 11:33:06 +08:00
shaw
420bedd615 Merge PR #8: refactor(backend): 添加 service 缓存端口 2025-12-20 11:05:01 +08:00
shaw
a79f6c5e1e feat: 给所有表格页面增加刷新按钮 2025-12-20 10:48:42 +08:00
shaw
0484c59ead feat: /admin/usage页面增加模型分布情况显示 2025-12-20 10:06:55 +08:00
Forest
7bbf621490 refactor(backend): 添加 service 缓存端口 2025-12-19 23:44:18 +08:00
shaw
ef81aeb463 fix: 修复dashboard页面用户名的显示bug 2025-12-19 22:41:26 +08:00
shaw
22414326cc fix: 修复前端切换页面时logo跟标题闪烁的问题 2025-12-19 22:33:36 +08:00
Wesley Liddick
14b155c66b Merge pull request #7 from NepetaLemon/refactor/ports-pattern
refactor(backend): 引入端口接口模式
2025-12-19 08:29:04 -05:00
Forest
e99b344b2b refactor(backend): 引入端口接口模式 2025-12-19 21:26:19 +08:00
shaw
7fd94ab78b fix: 修复usage页面未显示缓存写入的问题 2025-12-19 16:57:31 +08:00
shaw
078529e51e chore: 更新docker的postgres版本为18 2025-12-19 16:42:03 +08:00
shaw
23a4cf11c8 fix: 设置默认logo作为favicon 2025-12-19 16:41:00 +08:00
shaw
d1f0902ec0 feat(account): 支持账号级别拦截预热请求
- 新增 intercept_warmup_requests 配置项,存储在 credentials 字段
- 启用后,标题生成、Warmup 等预热请求返回 mock 响应,不消耗上游 token
- 前端支持所有账号类型(OAuth、Setup Token、API Key)的开关配置
- 修复 OAuth 凭证刷新时丢失非 token 配置的问题
2025-12-19 16:39:25 +08:00
shaw
ee86dbca9d feat(account): 账号测试支持选择模型
- 新增 GET /api/v1/admin/accounts/:id/models 接口获取账号可用模型
- 账号测试弹窗新增模型选择下拉框
- 测试时支持传入 model_id 参数,不传则默认使用 Sonnet
- API Key 账号支持根据 model_mapping 映射测试模型
- 将模型常量提取到 claude 包统一管理
2025-12-19 16:00:09 +08:00
Wesley Liddick
733d4c2b85 Merge pull request #6 from dexcoder6/main
fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
2025-12-19 02:59:05 -05:00
dexcoder6
406d3f3cab fix(frontend): 修复移动端菜单栏和使用记录页面 UI 问题
- 修复移动端无法打开菜单栏的问题
  - 在 app.ts 中添加 mobileOpen 状态管理
  - 修复 AppHeader.vue 中移动端菜单按钮调用错误的方法
  - 修复 AppSidebar.vue 使用本地 ref 而非全局状态的问题

- 添加移动端菜单自动关闭功能
  - 点击菜单项后自动关闭侧边栏
  - 添加 150ms 延迟以显示关闭动画

- 修复使用记录页面总消费卡片溢出问题
  - 调整总消费卡片布局,将删除线价格移至说明行
  - 添加 min-w-0 flex-1 防止内容溢出
  - 保持与其他卡片高度一致

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-19 15:55:42 +08:00
shaw
1ed93a5fd0 refactor: 提取 Claude 客户端常量到独立包
- 新增 internal/pkg/claude 包统一管理 Claude Code 相关常量
  - 统一账号测试逻辑,所有账号类型使用相同的 Claude Code 风格请求
  - 网关服务使用常量包替换硬编码的 beta header 字符串
2025-12-19 15:22:52 +08:00
shaw
463ddea36f fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
batchInputHint 中的 @ 符号需要使用 {'@'} 转义
2025-12-19 11:24:22 +08:00
shaw
e769f67699 fix(setup): 支持从配置文件读取 Setup Wizard 监听地址
Setup Wizard 之前硬编码使用 8080 端口,现在支持从 config.yaml 或
环境变量 (SERVER_HOST, SERVER_PORT) 读取监听地址,方便用户在端口
被占用时使用其他地址启动初始化向导。
2025-12-19 11:21:58 +08:00
shaw
52d2ae9708 feat(gateway): 添加 /v1/messages/count_tokens 端点
实现 Claude API 的 token 计数功能,支持 OAuth、SetupToken 和 ApiKey 三种账号类型。

特点:
- 校验订阅/余额(不扣费)
- 不计算用户和账号并发
- 不记录使用量
- 支持模型映射(ApiKey 账号)
- 支持 OAuth 账号的指纹管理和 401 重试
2025-12-19 11:12:41 +08:00
shaw
2e59998c51 fix: 代理表单字段保存时自动去除前后空格
前后端同时处理,防止因意外空格导致代理连接失败
2025-12-19 10:39:30 +08:00
shaw
32e58115cc fix(frontend): 修复代理快捷添加弹窗的 i18n 解析错误
转义 batchInputPlaceholder 中的 @ 符号,防止 Vue I18n 将其误解析为链接消息语法
2025-12-19 10:32:22 +08:00
shaw
ba27026399 docs: 调整源码编译步骤的顺序 2025-12-19 09:47:17 +08:00
shaw
c15b419c4c feat(backend): 添加 event_logging 接口直接返回200
将原本在nginx处理的遥测日志请求移至后端,
忽略Claude Code客户端发送的日志数据。
2025-12-19 09:39:57 +08:00
shaw
5bd27a5d17 fix(frontend): 优化分组表单中订阅模式的字段显示逻辑
- 订阅模式下隐藏 Exclusive 字段并默认为开启状态
- 编辑分组时禁用计费类型字段,防止修改
- 移除编辑表单中无用的 subscription_type watch
2025-12-19 08:41:30 +08:00
Wesley Liddick
0e7b8aab8c Merge pull request #4 from NepetaLemon/refactor/backend-wire-provider-sets
refactor(backend): 拆分 Wire ProviderSet
2025-12-18 19:27:49 -05:00
Forest
236908c03d refactor(backend): 拆分 Wire ProviderSet 2025-12-19 00:03:29 +08:00
shaw
67d028cf50 fix: 修复用户修改密码接口404问题
将后端路由与前端API调用对齐:
- /user/profile -> /users/me
- PUT /user/password -> POST /users/me/password
2025-12-18 22:59:49 +08:00
shaw
66ba487697 fix: 修复前端github项目地址 2025-12-18 22:47:42 +08:00
Wesley Liddick
8c7875aa4d Merge pull request #3 from NepetaLemon/refactor/backend-wire-bootstrap
refactor(backend): 引入 Wire 重构服务启动与依赖组装
2025-12-18 09:12:15 -05:00
shaw
145171464f fix: 修复前端多个 bug
1. 版本号闪烁问题
   - 将版本信息缓存到 Pinia store,避免每次路由切换都重新请求
   - 添加加载占位符,版本为空时显示骨架屏

2. 管理员登录跳转问题
   - 管理员登录后现在正确跳转到 /admin/dashboard
   - 普通用户仍跳转到 /dashboard

3. Dashboard 页面空白报错
   - 修复 API 返回 null 时访问 .length 导致的 TypeError
   - 为 computed 属性添加可选链操作符保护
   - 为数据赋值添加空数组默认值
2025-12-18 22:11:29 +08:00
Forest
e5aa676853 refactor(backend): 引入 Wire 重构服务启动与依赖组装 2025-12-18 22:07:17 +08:00
shaw
9b4fc42457 feat: 实现后台在线更新功能
- 前端添加更新和重启按钮,支持一键更新 Release 构建
- 修复条件判断优先级问题,确保错误/成功状态正确显示
- 后端使用原子文件替换模式,确保更新过程安全可靠
- 在可执行文件同目录创建临时文件,保证 rename 原子性
- 删除未使用的 copyFile 函数,保持代码整洁
2025-12-18 21:15:10 +08:00
shaw
caae7e4603 feat: 改进安装脚本的交互体验和自动化流程
- 修复 curl | bash 管道模式下无法交互式输入的问题
  - 使用 /dev/tty 检测终端可用性替代 stdin 检测
  - 所有 read 命令从 /dev/tty 读取用户输入
- 安装完成后自动启动服务和启用开机自启
- 使用 ipinfo.io API 获取公网 IP 用于显示访问地址
- 简化安装完成后的输出信息
2025-12-18 20:53:29 +08:00
shaw
a26db8b3e2 fix: 修复前端页面刷新时偶发空白渲染的竞态条件问题
使用 router.isReady() 等待路由器完成初始导航后再挂载应用,
避免 RouterView 在路由未就绪时渲染空的 Comment 节点。
2025-12-18 20:45:56 +08:00
shaw
8e81e395b3 refactor: 使用行业标准方案重构服务重启逻辑
重构内容:
- 移除复杂的 sudo systemctl restart 方案
- 改用 os.Exit(0) + systemd Restart=always 的标准做法
- 删除 sudoers 配置及相关代码
- 删除 sub2api-sudoers 文件

优势:
- 代码从 85+ 行简化到 47 行
- 无需 sudo 权限配置
- 无需特殊用户 shell 配置
- 更简单、更可靠
- 符合行业最佳实践(Docker/K8s 等均采用此方案)

工作原理:
- 服务调用 os.Exit(0) 优雅退出
- systemd 检测到退出后自动重启(Restart=always)
2025-12-18 20:32:24 +08:00
142 changed files with 6405 additions and 2772 deletions

View File

@@ -16,6 +16,14 @@ English | [中文](README_CN.md)
---
## Demo
Try Sub2API online: **https://v2.pincc.ai/**
| Email | Password |
|-------|----------|
| admin@sub2api.com | admin123 |
## Overview
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
@@ -208,20 +216,19 @@ Build and run from source code for development or customization.
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. Build backend
cd backend
go build -o sub2api ./cmd/server
# 3. Build frontend
cd ../frontend
# 2. Build frontend
cd frontend
npm install
npm run build
# 4. Copy frontend build to backend (for embedding)
# 3. Copy frontend build to backend (for embedding)
cp -r dist ../backend/internal/web/
# 5. Create configuration file
# 4. Build backend (requires frontend dist to be present)
cd ../backend
go build -o sub2api ./cmd/server
# 5. Create configuration file
cp ../deploy/config.example.yaml ./config.yaml
# 6. Edit configuration

View File

@@ -16,6 +16,14 @@
---
## 在线体验
体验地址:**https://v2.pincc.ai/**
| 邮箱 | 密码 |
|------|------|
| admin@sub2api.com | admin123 |
## 项目概述
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
@@ -208,20 +216,19 @@ docker-compose logs -f
git clone https://github.com/Wei-Shaw/sub2api.git
cd sub2api
# 2. 编译
cd backend
go build -o sub2api ./cmd/server
# 3. 编译前端
cd ../frontend
# 2. 编译
cd frontend
npm install
npm run build
# 4. 复制前端构建产物到后端(用于嵌入)
# 3. 复制前端构建产物到后端(用于嵌入)
cp -r dist ../backend/internal/web/
# 5. 创建配置文件
# 4. 编译后端(需要前端 dist 目录存在)
cd ../backend
go build -o sub2api ./cmd/server
# 5. 创建配置文件
cp ../deploy/config.example.yaml ./config.yaml
# 6. 编辑配置

6
backend/Makefile Normal file
View File

@@ -0,0 +1,6 @@
.PHONY: wire
wire:
@echo "生成 Wire 代码..."
@cd cmd/server && go generate
@echo "Wire 代码生成完成"

View File

@@ -1,8 +1,11 @@
package main
//go:generate go run github.com/google/wire/cmd/wire
import (
"context"
_ "embed"
"errors"
"flag"
"log"
"net/http"
@@ -15,18 +18,10 @@ import (
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/setup"
"sub2api/internal/web"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
//go:embed VERSION
@@ -100,8 +95,10 @@ func runSetupServer() {
r.Use(web.ServeEmbeddedFrontend())
}
addr := ":8080"
log.Printf("Setup wizard available at http://localhost%s", addr)
// Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
// This allows users to run setup on a different address if needed
addr := config.GetServerAddress()
log.Printf("Setup wizard available at http://%s", addr)
log.Println("Complete the setup wizard to configure Sub2API")
if err := r.Run(addr); err != nil {
@@ -110,78 +107,25 @@ func runSetupServer() {
}
func runMainServer() {
// 加载配置
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 初始化时区(类似 PHP 的 date_default_timezone_set
if err := timezone.Init(cfg.Timezone); err != nil {
log.Fatalf("Failed to initialize timezone: %v", err)
}
// 初始化数据库
db, err := initDB(cfg)
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
// 初始化Redis
rdb := initRedis(cfg)
// 初始化Repository
repos := repository.NewRepositories(db)
// 初始化Service
services := service.NewServices(repos, rdb, cfg)
// 初始化Handler
buildInfo := handler.BuildInfo{
Version: Version,
BuildType: BuildType,
}
handlers := handler.NewHandlers(services, repos, rdb, buildInfo)
// 设置Gin模式
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
// 创建路由
r := gin.New()
r.Use(gin.Recovery())
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
app, err := initializeApplication(buildInfo)
if err != nil {
log.Fatalf("Failed to initialize application: %v", err)
}
defer app.Cleanup()
// 启动服务器
srv := &http.Server{
Addr: cfg.Server.Address(),
Handler: r,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
// 优雅关闭
go func() {
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if err := app.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start server: %v", err)
}
}()
log.Printf("Server started on %s", cfg.Server.Address())
log.Printf("Server started on %s", app.Server.Addr)
// 等待中断信号
quit := make(chan os.Signal, 1)
@@ -193,289 +137,9 @@ func runMainServer() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
if err := app.Server.Shutdown(ctx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err)
}
log.Println("Server exited")
}
func initDB(cfg *config.Config) (*gorm.DB, error) {
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}
func initRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
// OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
}
}

117
backend/cmd/server/wire.go Normal file
View File

@@ -0,0 +1,117 @@
//go:build wireinject
// +build wireinject
package main
import (
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"context"
"log"
"net/http"
"time"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
type Application struct {
Server *http.Server
Cleanup func()
}
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
wire.Build(
// 基础设施层 ProviderSets
config.ProviderSet,
infrastructure.ProviderSet,
// 业务层 ProviderSets
repository.ProviderSet,
service.ProviderSet,
handler.ProviderSet,
// 服务器层 ProviderSet
server.ProviderSet,
// BuildInfo provider
provideServiceBuildInfo,
// 清理函数提供者
provideCleanup,
// 应用程序结构体
wire.Struct(new(Application), "Server", "Cleanup"),
)
return nil, nil
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
BuildType: buildInfo.BuildType,
}
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Cleanup steps in reverse dependency order
cleanupSteps := []struct {
name string
fn func() error
}{
{"TokenRefreshService", func() error {
services.TokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
{"Database", func() error {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}},
}
for _, step := range cleanupSteps {
if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err)
// Continue with remaining cleanup steps even if one fails
} else {
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
// Check if context timed out
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default:
log.Printf("[Cleanup] All cleanup steps completed")
}
}
}

View File

@@ -0,0 +1,232 @@
// Code generated by Wire. DO NOT EDIT.
//go:generate go run -mod=mod github.com/google/wire/cmd/wire
//go:build !wireinject
// +build !wireinject
package main
import (
"context"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
"log"
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/handler/admin"
"sub2api/internal/infrastructure"
"sub2api/internal/repository"
"sub2api/internal/server"
"sub2api/internal/service"
"time"
)
import (
_ "embed"
)
// Injectors from wire.go:
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
configConfig, err := config.ProvideConfig()
if err != nil {
return nil, err
}
db, err := infrastructure.ProvideDB(configConfig)
if err != nil {
return nil, err
}
userRepository := repository.NewUserRepository(db)
settingRepository := repository.NewSettingRepository(db)
settingService := service.NewSettingService(settingRepository, configConfig)
client := infrastructure.ProvideRedis(configConfig)
emailCache := repository.NewEmailCache(client)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository, configConfig)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
apiKeyCache := repository.NewApiKeyCache(client)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
billingCache := repository.NewBillingCache(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(client)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
proxyExitInfoProber := repository.NewProxyExitInfoProber()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
claudeUpstream := repository.NewClaudeUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService)
updateCache := repository.NewUpdateCache(client)
gitHubReleaseClient := repository.NewGitHubReleaseClient()
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client)
pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
groupService := service.NewGroupService(groupRepository)
accountService := service.NewAccountService(accountRepository, groupRepository)
proxyService := service.NewProxyService(proxyRepository)
services := &service.Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
Update: updateService,
TokenRefresh: tokenRefreshService,
}
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services)
application := &Application{
Server: httpServer,
Cleanup: v,
}
return application, nil
}
// wire.go:
type Application struct {
Server *http.Server
Cleanup func()
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
BuildType: buildInfo.BuildType,
}
}
func provideCleanup(
db *gorm.DB,
rdb *redis.Client,
services *service.Services,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cleanupSteps := []struct {
name string
fn func() error
}{
{"TokenRefreshService", func() error {
services.TokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
return nil
}},
{"EmailQueueService", func() error {
services.EmailQueue.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},
{"Database", func() error {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}},
}
for _, step := range cleanupSteps {
if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err)
} else {
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default:
log.Printf("[Cleanup] All cleanup steps completed")
}
}
}

View File

@@ -13,6 +13,7 @@ require (
github.com/redis/go-redis/v9 v9.3.0
github.com/spf13/viper v1.18.2
golang.org/x/crypto v0.44.0
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.4
@@ -33,6 +34,8 @@ require (
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/google/wire v0.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -50,6 +53,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.56.0 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
@@ -66,9 +70,11 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)

View File

@@ -48,8 +48,12 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
@@ -154,8 +158,12 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
@@ -166,6 +174,8 @@ golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=

View File

@@ -8,15 +8,30 @@ import (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
Enabled bool `mapstructure:"enabled"`
// 检查间隔(分钟)
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
// 提前刷新时间小时在token过期前多久开始刷新
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
// 最大重试次数
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
}
type PricingConfig struct {
@@ -192,6 +207,13 @@ func setDefaults() {
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
}
func (c *Config) Validate() error {
@@ -203,3 +225,29 @@ func (c *Config) Validate() error {
}
return nil
}
// GetServerAddress returns the server address (host:port) from config file or environment variable.
// This is a lightweight function that can be used before full config validation,
// such as during setup wizard startup.
// Priority: config.yaml > environment variables > defaults
func GetServerAddress() string {
v := viper.New()
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/sub2api")
// Support SERVER_HOST and SERVER_PORT environment variables
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.port", 8080)
// Try to read config file (ignore errors if not found)
_ = v.ReadInConfig()
host := v.GetString("server.host")
port := v.GetInt("server.port")
return fmt.Sprintf("%s:%d", host, port)
}

View File

@@ -0,0 +1,13 @@
package config
import "github.com/google/wire"
// ProviderSet 提供配置层的依赖
var ProviderSet = wire.NewSet(
ProvideConfig,
)
// ProvideConfig 提供应用配置
func ProvideConfig() (*Config, error) {
return Load()
}

View File

@@ -3,6 +3,7 @@ package admin
import (
"strconv"
"sub2api/internal/pkg/claude"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
@@ -186,6 +187,11 @@ func (h *AccountHandler) Delete(c *gin.Context) {
response.Success(c, gin.H{"message": "Account deleted successfully"})
}
// TestAccountRequest represents the request body for testing an account
type TestAccountRequest struct {
ModelID string `json:"model_id"`
}
// Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) {
@@ -195,8 +201,12 @@ func (h *AccountHandler) Test(c *gin.Context) {
return
}
var req TestAccountRequest
// Allow empty body, model_id is optional
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil {
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
// Error already sent via SSE, just log
return
}
@@ -231,16 +241,20 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return
}
// Update account credentials
newCredentials := map[string]interface{}{
"access_token": tokenInfo.AccessToken,
"token_type": tokenInfo.TokenType,
"expires_in": tokenInfo.ExpiresIn,
"expires_at": tokenInfo.ExpiresAt,
"refresh_token": tokenInfo.RefreshToken,
"scope": tokenInfo.Scope,
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
newCredentials := make(map[string]interface{})
for k, v := range account.Credentials {
newCredentials[k] = v
}
// Update token-related fields
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = tokenInfo.ExpiresIn
newCredentials["expires_at"] = tokenInfo.ExpiresAt
newCredentials["refresh_token"] = tokenInfo.RefreshToken
newCredentials["scope"] = tokenInfo.Scope
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
@@ -535,3 +549,58 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
response.Success(c, account)
}
// GetAvailableModels handles getting available models for an account
// GET /api/v1/admin/accounts/:id/models
func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
response.Success(c, claude.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if mapping == nil || len(mapping) == 0 {
// No mapping configured, return default models
response.Success(c, claude.DefaultModels)
return
}
// Return mapped models (keys of the mapping are the available model IDs)
var models []claude.Model
for requestedModel := range mapping {
// Try to find display info from default models
var found bool
for _, dm := range claude.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
// If not found in defaults, create a basic entry
if !found {
models = append(models, claude.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
}

View File

@@ -127,12 +127,25 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour)
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity)
// Parse optional filter params
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
trend, err := h.usageRepo.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -148,11 +161,24 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD)
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime)
// Parse optional filter params
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return

View File

@@ -2,6 +2,7 @@ package admin
import (
"strconv"
"strings"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
@@ -112,12 +113,12 @@ func (h *ProxyHandler) Create(c *gin.Context) {
}
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: req.Username,
Password: req.Password,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error())
@@ -143,13 +144,13 @@ func (h *ProxyHandler) Update(c *gin.Context) {
}
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: req.Status,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
Status: strings.TrimSpace(req.Status),
})
if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error())
@@ -263,8 +264,14 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
skipped := 0
for _, item := range req.Proxies {
// Trim all string fields
host := strings.TrimSpace(item.Host)
protocol := strings.TrimSpace(item.Protocol)
username := strings.TrimSpace(item.Username)
password := strings.TrimSpace(item.Password)
// Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
return
@@ -278,11 +285,11 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default",
Protocol: item.Protocol,
Host: item.Host,
Protocol: protocol,
Host: host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
Username: username,
Password: password,
})
if err != nil {
// If creation fails due to duplicate, count as skipped

View File

@@ -256,3 +256,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
response.Success(c, gin.H{"message": "Test email sent successfully"})
}
// GetAdminApiKey 获取管理员 API Key 状态
// GET /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
return
}
response.Success(c, gin.H{
"exists": exists,
"masked_key": maskedKey,
})
}
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
// POST /api/v1/admin/settings/admin-api-key/regenerate
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
return
}
response.Success(c, gin.H{
"key": key, // 完整 key 只在生成时返回一次
})
}
// DeleteAdminApiKey 删除管理员 API Key
// DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Admin API key deleted"})
}

View File

@@ -4,15 +4,15 @@ import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
if p == nil {
return nil
}

View File

@@ -9,7 +9,6 @@ import (
"sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SystemHandler handles system-related operations
@@ -18,9 +17,9 @@ type SystemHandler struct {
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
return &SystemHandler{
updateSvc: service.NewUpdateService(rdb, version, buildType),
updateSvc: updateSvc,
}
}

View File

@@ -4,6 +4,7 @@ import (
"strconv"
"time"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
@@ -14,10 +15,10 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
adminService service.AdminService
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
@@ -82,7 +83,7 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,

View File

@@ -4,8 +4,8 @@ import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
}
page, pageSize := response.ParsePagination(c)
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil {

View File

@@ -7,10 +7,12 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/pkg/claude"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -126,6 +128,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil {
@@ -285,29 +297,8 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id
// Models handles listing available models
// GET /v1/models
func (h *GatewayHandler) Models(c *gin.Context) {
models := []gin.H{
{
"id": "claude-opus-4-5-20251101",
"type": "model",
"display_name": "Claude Opus 4.5",
"created_at": "2025-11-01T00:00:00Z",
},
{
"id": "claude-sonnet-4-5-20250929",
"type": "model",
"display_name": "Claude Sonnet 4.5",
"created_at": "2025-09-29T00:00:00Z",
},
{
"id": "claude-haiku-4-5-20251001",
"type": "model",
"display_name": "Claude Haiku 4.5",
"created_at": "2025-10-01T00:00:00Z",
},
}
c.JSON(http.StatusOK, gin.H{
"data": models,
"data": claude.DefaultModels,
"object": "list",
})
}
@@ -443,3 +434,155 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
},
})
}
// CountTokens handles token counting endpoint
// POST /v1/messages/count_tokens
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 获取订阅信息可能为nil
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return
}
}
// isWarmupRequest 检测是否为预热请求标题生成、Warmup等
func isWarmupRequest(body []byte) bool {
// 快速检查如果body不包含关键字直接返回false
bodyStr := string(body)
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
return false
}
// 解析完整请求
var req struct {
Messages []struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
} `json:"messages"`
System []struct {
Text string `json:"text"`
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
return false
}
// 检查 messages 中的标题提示模式
for _, msg := range req.Messages {
for _, content := range msg.Content {
if content.Type == "text" {
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
content.Text == "Warmup" {
return true
}
}
}
}
// 检查 system 中的标题提取模式
for _, system := range req.System {
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
return true
}
}
return false
}
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
func sendMockWarmupStream(c *gin.Context, model string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
events := []string{
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
time.Sleep(20 * time.Millisecond)
}
}
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
func sendMockWarmupResponse(c *gin.Context, model string) {
c.JSON(http.StatusOK, gin.H{
"id": "msg_mock_warmup",
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
"output_tokens": 2,
},
})
}

View File

@@ -2,10 +2,6 @@ package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// AdminHandlers contains all admin-related HTTP handlers
@@ -41,30 +37,3 @@ type BuildInfo struct {
Version string
BuildType string // "source" for manual builds, "release" for CI builds
}
// NewHandlers creates a new Handlers instance with all handlers initialized
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
return &Handlers{
Auth: NewAuthHandler(services.Auth),
User: NewUserHandler(services.User),
APIKey: NewAPIKeyHandler(services.ApiKey),
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
Redeem: NewRedeemHandler(services.Redeem),
Subscription: NewSubscriptionHandler(services.Subscription),
Admin: &AdminHandlers{
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
User: admin.NewUserHandler(services.Admin),
Group: admin.NewGroupHandler(services.Admin),
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
Proxy: admin.NewProxyHandler(services.Admin),
Redeem: admin.NewRedeemHandler(services.Admin),
Setting: admin.NewSettingHandler(services.Setting, services.Email),
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
Subscription: admin.NewSubscriptionHandler(services.Subscription),
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
},
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
}
}

View File

@@ -5,6 +5,7 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
@@ -68,9 +69,9 @@ func (h *UsageHandler) List(c *gin.Context) {
apiKeyID = id
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog
var result *repository.PaginationResult
var result *pagination.PaginationResult
var err error
if apiKeyID > 0 {
@@ -362,7 +363,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
}
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.InternalError(c, "Failed to verify API key ownership")
return

View File

@@ -0,0 +1,102 @@
package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/service"
"github.com/google/wire"
)
// ProvideAdminHandlers creates the AdminHandlers struct
func ProvideAdminHandlers(
dashboardHandler *admin.DashboardHandler,
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
oauthHandler *admin.OAuthHandler,
proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler,
settingHandler *admin.SettingHandler,
systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
User: userHandler,
Group: groupHandler,
Account: accountHandler,
OAuth: oauthHandler,
Proxy: proxyHandler,
Redeem: redeemHandler,
Setting: settingHandler,
System: systemHandler,
Subscription: subscriptionHandler,
Usage: usageHandler,
}
}
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
return admin.NewSystemHandler(updateService)
}
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
return NewSettingHandler(settingService, buildInfo.Version)
}
// ProvideHandlers creates the Handlers struct
func ProvideHandlers(
authHandler *AuthHandler,
userHandler *UserHandler,
apiKeyHandler *APIKeyHandler,
usageHandler *UsageHandler,
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
settingHandler *SettingHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
User: userHandler,
APIKey: apiKeyHandler,
Usage: usageHandler,
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
Setting: settingHandler,
}
}
// ProviderSet is the Wire provider set for all handlers
var ProviderSet = wire.NewSet(
// Top-level handlers
NewAuthHandler,
NewUserHandler,
NewAPIKeyHandler,
NewUsageHandler,
NewRedeemHandler,
NewSubscriptionHandler,
NewGatewayHandler,
ProvideSettingHandler,
// Admin handlers
admin.NewDashboardHandler,
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewOAuthHandler,
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewSettingHandler,
ProvideSystemHandler,
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
ProvideHandlers,
)

View File

@@ -0,0 +1,38 @@
package infrastructure
import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// InitDB 初始化数据库连接
func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 初始化时区(在数据库连接之前,确保时区设置正确)
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, err
}
gormConfig := &gorm.Config{}
if cfg.Server.Mode == "debug" {
gormConfig.Logger = logger.Default.LogMode(logger.Info)
}
// 使用带时区的 DSN 连接数据库
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
if err != nil {
return nil, err
}
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
return nil, err
}
return db, nil
}

View File

@@ -0,0 +1,16 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}

View File

@@ -0,0 +1,25 @@
package infrastructure
import (
"sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// ProviderSet 提供基础设施层的依赖
var ProviderSet = wire.NewSet(
ProvideDB,
ProvideRedis,
)
// ProvideDB 提供数据库连接
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
return InitDB(cfg)
}
// ProvideRedis 提供 Redis 客户端
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -0,0 +1,130 @@
package middleware
import (
"context"
"crypto/subtle"
"strings"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminAuth 管理员认证中间件
// 支持两种认证方式(通过不同的 header 区分):
// 1. Admin API Key: x-api-key: <admin-api-key>
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
func AdminAuth(
authService *service.AuthService,
userRepo interface {
GetByID(ctx context.Context, id int64) (*model.User, error)
GetFirstAdmin(ctx context.Context) (*model.User, error)
},
settingService *service.SettingService,
) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userRepo) {
return
}
c.Next()
return
}
// 检查 Authorization headerJWT 认证)
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userRepo) {
return
}
c.Next()
return
}
}
// 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
}
}
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userRepo interface {
GetFirstAdmin(ctx context.Context) (*model.User, error)
},
) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false
}
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
return false
}
// 获取真实的管理员用户
admin, err := userRepo.GetFirstAdmin(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
return false
}
c.Set(string(ContextKeyUser), admin)
c.Set("auth_method", "admin_api_key")
return true
}
// validateJWTForAdmin 验证 JWT 并检查管理员权限
func validateJWTForAdmin(
c *gin.Context,
token string,
authService *service.AuthService,
userRepo interface {
GetByID(ctx context.Context, id int64) (*model.User, error)
},
) bool {
// 验证 JWT token
claims, err := authService.ValidateToken(token)
if err != nil {
if err == service.ErrTokenExpired {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return false
}
// 从数据库获取用户
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return false
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return false
}
// 检查管理员权限
if user.Role != model.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false
}
c.Set(string(ContextKeyUser), user)
c.Set("auth_method", "jwt")
return true
}

View File

@@ -263,3 +263,17 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
}
return false
}
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后标题生成、Warmup等预热请求将返回mock响应不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}

View File

@@ -46,8 +46,14 @@ const (
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
)
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
const AdminApiKeyPrefix = "admin-"
// SystemSettings 系统设置结构体用于API响应
type SystemSettings struct {
// 注册设置

View File

@@ -0,0 +1,74 @@
package claude
// Claude Code 客户端相关常量
// Beta header 常量
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
// Model 表示一个 Claude 模型
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels Claude Code 客户端支持的默认模型列表
var DefaultModels = []Model{
{
ID: "claude-opus-4-5-20251101",
Type: "model",
DisplayName: "Claude Opus 4.5",
CreatedAt: "2025-11-01T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
DisplayName: "Claude Sonnet 4.5",
CreatedAt: "2025-09-29T00:00:00Z",
},
{
ID: "claude-haiku-4-5-20251001",
Type: "model",
DisplayName: "Claude Haiku 4.5",
CreatedAt: "2025-10-01T00:00:00Z",
},
}
// DefaultModelIDs 返回默认模型的 ID 列表
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"

View File

@@ -0,0 +1,42 @@
package pagination
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}

View File

@@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in
})
}
// PaginationResult 分页结果(与repository.PaginationResult兼容
// PaginationResult 分页结果(与pagination.PaginationResult兼容
type PaginationResult struct {
Total int64
Page int

View File

@@ -1,88 +1,39 @@
package sysutil
import (
"fmt"
"log"
"os"
"os/exec"
"runtime"
"time"
)
const serviceName = "sub2api"
// findExecutable finds the full path of an executable
// by checking common system paths
func findExecutable(name string) string {
// First try exec.LookPath (uses current PATH)
if path, err := exec.LookPath(name); err == nil {
return path
}
// Fallback: check common paths
commonPaths := []string{
"/usr/bin/" + name,
"/bin/" + name,
"/usr/sbin/" + name,
"/sbin/" + name,
}
for _, path := range commonPaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
// Return the name as-is and let exec fail with a clear error
return name
}
// RestartService triggers a service restart via systemd.
// RestartService triggers a service restart by gracefully exiting.
//
// IMPORTANT: This function initiates the restart and returns immediately.
// The actual restart happens asynchronously - the current process will be killed
// by systemd and a new process will be started.
//
// We use Start() instead of Run() because:
// - systemctl restart will kill the current process first
// - Run() waits for completion, but the process dies before completion
// - Start() spawns the command independently, allowing systemd to handle the full cycle
// This relies on systemd's Restart=always configuration to automatically
// restart the service after it exits. This is the industry-standard approach:
// - Simple and reliable
// - No sudo permissions needed
// - No complex process management
// - Leverages systemd's native restart capability
//
// Prerequisites:
// - Linux OS with systemd
// - NOPASSWD sudo access configured (install.sh creates /etc/sudoers.d/sub2api)
// - Service configured with Restart=always in systemd unit file
func RestartService() error {
if runtime.GOOS != "linux" {
return fmt.Errorf("systemd restart only available on Linux")
log.Println("Service restart via exit only works on Linux with systemd")
return nil
}
log.Println("Initiating service restart...")
log.Println("Initiating service restart by graceful exit...")
log.Println("systemd will automatically restart the service (Restart=always)")
// Find full paths for sudo and systemctl
// This ensures the commands work even if PATH is limited in systemd service
sudoPath := findExecutable("sudo")
systemctlPath := findExecutable("systemctl")
// Give a moment for logs to flush and response to be sent
go func() {
time.Sleep(100 * time.Millisecond)
os.Exit(0)
}()
log.Printf("Using sudo: %s, systemctl: %s", sudoPath, systemctlPath)
// The sub2api user has NOPASSWD sudo access for systemctl commands
// (configured by install.sh in /etc/sudoers.d/sub2api).
// Use -n (non-interactive) to prevent sudo from waiting for password input
//
// Use setsid to create a new session, ensuring the child process
// survives even if the parent process is killed by systemctl restart
setsidPath := findExecutable("setsid")
cmd := exec.Command(setsidPath, sudoPath, "-n", systemctlPath, "restart", serviceName)
// Detach from parent's stdio to ensure clean separation
cmd.Stdin = nil
cmd.Stdout = nil
cmd.Stderr = nil
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to initiate service restart: %w", err)
}
log.Println("Service restart initiated successfully")
return nil
}

View File

@@ -0,0 +1,8 @@
package usagestats
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm"
@@ -47,12 +48,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
}
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account
var total int64
@@ -94,7 +95,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
pages++
}
return accounts, &PaginationResult{
return accounts, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -226,7 +227,7 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"rate_limited_at": now,
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
}

View File

@@ -0,0 +1,51 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
)
type apiKeyCache struct {
rdb *redis.Client
}
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
return &apiKeyCache{rdb: rdb}
}
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return c.rdb.Incr(ctx, apiKey).Err()
}
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
}
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
var total int64
@@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
pages++
}
return keys, &PaginationResult{
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err
}
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
var total int64
@@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
pages++
}
return keys, &PaginationResult{
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -0,0 +1,174 @@
package repository
import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
rdb *redis.Client
}
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
return &billingCache{rdb: rdb}
}
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return 0, err
}
return strconv.ParseFloat(val, 64)
}
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
return c.parseSubscriptionCache(result)
}
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
result := &ports.SubscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
if data == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
}
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,235 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"time"
"sub2api/internal/pkg/oauth"
"sub2api/internal/service"
"github.com/imroc/req/v3"
)
type claudeOAuthService struct{}
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{}
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
reqBody := map[string]interface{}{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID,
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
authCode := code
codeState := ""
if len(code) > 0 {
parts := make([]string, 0, 2)
for i, part := range []rune(code) {
if part == '#' {
authCode = code[:i]
codeState = code[i+1:]
break
}
}
if len(parts) == 0 {
authCode = code
}
}
reqBody := map[string]interface{}{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
if codeState != "" {
reqBody["state"] = codeState
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -0,0 +1,64 @@
package repository
import (
"net/http"
"net/url"
"time"
"sub2api/internal/config"
"sub2api/internal/service"
)
type claudeUpstreamService struct {
defaultClient *http.Client
cfg *config.Config
}
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &claudeUpstreamService{
defaultClient: &http.Client{Transport: transport},
cfg: cfg,
}
}
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
if proxyURL == "" {
return s.defaultClient.Do(req)
}
client := s.createProxyClient(proxyURL)
return client.Do(req)
}
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return s.defaultClient
}
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &http.Client{Transport: transport}
}

View File

@@ -0,0 +1,59 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"sub2api/internal/service"
)
type claudeUsageService struct{}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{}
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
var usageResp service.ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
return &usageResp, nil
}

View File

@@ -0,0 +1,132 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
accountConcurrencyKeyPrefix = "concurrency:account:"
userConcurrencyKeyPrefix = "concurrency:user:"
waitQueueKeyPrefix = "concurrency:wait:"
concurrencyTTL = 5 * time.Minute
)
var (
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`)
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`)
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
type concurrencyCache struct {
rdb *redis.Client
}
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
}
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
return c.rdb.Get(ctx, key).Int()
}
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}

View File

@@ -0,0 +1,48 @@
package repository
import (
"context"
"encoding/json"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
type emailCache struct {
rdb *redis.Client
}
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
return &emailCache{rdb: rdb}
}
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data ports.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKeyPrefix + email
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,35 @@
package repository
import (
"context"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const stickySessionPrefix = "sticky_session:"
type gatewayCache struct {
rdb *redis.Client
}
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
return &gatewayCache{rdb: rdb}
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
key := stickySessionPrefix + sessionHash
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
return c.rdb.Expire(ctx, key, ttl).Err()
}

View File

@@ -0,0 +1,116 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"sub2api/internal/service"
)
type githubReleaseClient struct {
httpClient *http.Client
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release service.GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
return &release, nil
}
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxSize)
if written > maxSize {
os.Remove(dest) // Clean up partial file
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
}
return nil
}
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
}
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) {
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) {
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group
var total int64
@@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination
pages++
}
return groups, &PaginationResult{
return groups, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -0,0 +1,47 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
)
type identityCache struct {
rdb *redis.Client
}
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
return &identityCache{rdb: rdb}
}
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var fp ports.Fingerprint
if err := json.Unmarshal([]byte(val), &fp); err != nil {
return nil, err
}
return &fp, nil
}
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := json.Marshal(fp)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
}

View File

@@ -0,0 +1,73 @@
package repository
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"sub2api/internal/service"
)
type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient() service.PricingRemoteClient {
return &pricingRemoteClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
}

View File

@@ -0,0 +1,104 @@
package repository
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"sub2api/internal/service"
"golang.org/x/net/proxy"
)
type proxyProbeService struct{}
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{}
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
transport, err := createProxyTransport(proxyURL)
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
}
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
}
defer resp.Body.Close()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
}
return &service.ProxyExitInfo{
IP: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
}, latencyMs, nil
}
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
}
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) {
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) {
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy
var total int64
@@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination
pages++
}
return proxies, &PaginationResult{
return proxies, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -0,0 +1,49 @@
package repository
import (
"context"
"fmt"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const (
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemRateLimitDuration = 24 * time.Hour
)
type redeemCache struct {
rdb *redis.Client
}
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
return &redeemCache{rdb: rdb}
}
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
return c.rdb.Get(ctx, key).Int()
}
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
key := redeemLockKeyPrefix + code
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
key := redeemLockKeyPrefix + code
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm"
@@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
}
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) {
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) {
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode
var total int64
@@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin
pages++
}
return codes, &PaginationResult{
return codes, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -1,9 +1,5 @@
package repository
import (
"gorm.io/gorm"
)
// Repositories 所有仓库的集合
type Repositories struct {
User *UserRepository
@@ -16,59 +12,3 @@ type Repositories struct {
Setting *SettingRepository
UserSubscription *UserSubscriptionRepository
}
// NewRepositories 创建所有仓库
func NewRepositories(db *gorm.DB) *Repositories {
return &Repositories{
User: NewUserRepository(db),
ApiKey: NewApiKeyRepository(db),
Group: NewGroupRepository(db),
Account: NewAccountRepository(db),
Proxy: NewProxyRepository(db),
RedeemCode: NewRedeemCodeRepository(db),
UsageLog: NewUsageLogRepository(db),
Setting: NewSettingRepository(db),
UserSubscription: NewUserSubscriptionRepository(db),
}
}
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}

View File

@@ -0,0 +1,55 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"sub2api/internal/service"
)
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
type turnstileVerifier struct {
httpClient *http.Client
}
func NewTurnstileVerifier() service.TurnstileVerifier {
return &turnstileVerifier{
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := v.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
var result service.TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decode response: %w", err)
}
return &result, nil
}

View File

@@ -0,0 +1,28 @@
package repository
import (
"context"
"time"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
const updateCacheKey = "update:latest"
type updateCache struct {
rdb *redis.Client
}
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
return &updateCache{rdb: rdb}
}
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
return c.rdb.Get(ctx, updateCacheKey).Result()
}
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
}

View File

@@ -3,7 +3,9 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone"
"sub2api/internal/pkg/usagestats"
"time"
"gorm.io/gorm"
@@ -30,7 +32,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag
return &log, nil
}
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -49,7 +51,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
pages++
}
return logs, &PaginationResult{
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -57,7 +59,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
}, nil
}
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -76,7 +78,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
pages++
}
return logs, &PaginationResult{
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -270,7 +272,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil
}
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -289,7 +291,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
pages++
}
return logs, &PaginationResult{
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -297,7 +299,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}, nil
}
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
@@ -306,7 +308,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return logs, nil, err
}
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
@@ -315,7 +317,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return logs, nil, err
}
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
@@ -324,7 +326,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return logs, nil, err
}
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
@@ -337,15 +339,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
}
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}
// GetAccountTodayStats 获取账号今日统计
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*AccountStats, error) {
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today()
var stats struct {
@@ -367,7 +362,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
return nil, err
}
return &AccountStats{
return &usagestats.AccountStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
@@ -375,7 +370,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
}
// GetAccountWindowStats 获取账号时间窗口内的统计
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*AccountStats, error) {
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
var stats struct {
Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"`
@@ -395,7 +390,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
return nil, err
}
return &AccountStats{
return &usagestats.AccountStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
@@ -436,75 +431,13 @@ type UserUsageTrendPoint struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GetUsageTrend returns usage trend data grouped by date
// granularity: "day" or "hour"
func (r *UsageLogRepository) GetUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
var results []TrendDataPoint
// Choose date format based on granularity
var dateFormat string
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
} else {
dateFormat = "YYYY-MM-DD"
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`, dateFormat).
Where("created_at >= ? AND created_at < ?", startTime, endTime).
Group("date").
Order("date ASC").
Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetModelStats returns usage statistics grouped by model
func (r *UsageLogRepository) GetModelStats(ctx context.Context, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`).
Where("created_at >= ? AND created_at < ?", startTime, endTime).
Group("model").
Order("total_tokens DESC").
Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
@@ -780,7 +713,7 @@ type UsageLogFilters struct {
}
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -816,7 +749,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
pages++
}
return logs, &PaginationResult{
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -838,7 +771,7 @@ type UsageStats struct {
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
@@ -964,6 +897,76 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
return result, nil
}
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
var results []TrendDataPoint
var dateFormat string
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
} else {
dateFormat = "YYYY-MM-DD"
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`, dateFormat).
Where("created_at >= ? AND created_at < ?", startTime, endTime)
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
if apiKeyID > 0 {
db = db.Where("api_key_id = ?", apiKeyID)
}
err := db.Group("date").Order("date ASC").Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]ModelStat, error) {
var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
`).
Where("created_at >= ? AND created_at < ?", startTime, endTime)
if userID > 0 {
db = db.Where("user_id = ?", userID)
}
if apiKeyID > 0 {
db = db.Where("api_key_id = ?", apiKeyID)
}
err := db.Group("model").Order("total_tokens DESC").Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetGlobalStats gets usage statistics for all users within a time range
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
var stats struct {

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
}
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) {
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) {
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User
var total int64
@@ -81,7 +82,7 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP
pages++
}
return users, &PaginationResult{
return users, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -128,3 +129,15 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
return result.RowsAffected, result.Error
}
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Order("id ASC").
First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}

View File

@@ -5,6 +5,7 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
}
// ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) {
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64
@@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
pages++
}
return subs, &PaginationResult{
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
}
// List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) {
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64
@@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
pages++
}
return subs, &PaginationResult{
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -0,0 +1,51 @@
package repository
import (
"sub2api/internal/service/ports"
"github.com/google/wire"
)
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewApiKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewUsageLogRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"),
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewConcurrencyCache,
NewEmailCache,
NewIdentityCache,
NewRedeemCache,
NewUpdateCache,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
NewPricingRemoteClient,
NewGitHubReleaseClient,
NewProxyExitInfoProber,
NewClaudeUsageFetcher,
NewClaudeOAuthClient,
NewClaudeUpstream,
// Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
)

View File

@@ -0,0 +1,45 @@
package server
import (
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/repository"
"sub2api/internal/service"
"time"
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// ProviderSet 提供服务器层的依赖
var ProviderSet = wire.NewSet(
ProvideRouter,
ProvideHTTPServer,
)
// ProvideRouter 提供路由器
func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
r := gin.New()
r.Use(gin.Recovery())
return SetupRouter(r, cfg, handlers, services, repos)
}
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
return &http.Server{
Addr: cfg.Server.Address(),
Handler: router,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
}

View File

@@ -0,0 +1,293 @@
package server
import (
"net/http"
"sub2api/internal/config"
"sub2api/internal/handler"
"sub2api/internal/middleware"
"sub2api/internal/repository"
"sub2api/internal/service"
"sub2api/internal/web"
"github.com/gin-gonic/gin"
)
// SetupRouter 配置路由器中间件和路由
func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
// 应用中间件
r.Use(middleware.Logger())
r.Use(middleware.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
}
return r
}
// registerRoutes 注册所有 HTTP 路由
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Claude Code 遥测日志忽略直接返回200
r.POST("/api/event_logging/batch", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1
v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.AdminAuth(s.Auth, repos.User, s.Setting))
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
// OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
}
}

View File

@@ -5,7 +5,8 @@ import (
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -41,12 +42,12 @@ type UpdateAccountRequest struct {
// AccountService 账号管理服务
type AccountService struct {
accountRepo *repository.AccountRepository
groupRepo *repository.GroupRepository
accountRepo ports.AccountRepository
groupRepo ports.GroupRepository
}
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
return &AccountService{
accountRepo: accountRepo,
groupRepo: groupRepo,
@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
}
// List 获取账号列表
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err)

View File

@@ -10,12 +10,12 @@ import (
"io"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"sub2api/internal/repository"
"sub2api/internal/pkg/claude"
"sub2api/internal/service/ports"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -23,7 +23,6 @@ import (
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testModel = "claude-sonnet-4-5-20250929"
)
// TestEvent represents a SSE event for account testing
@@ -37,19 +36,17 @@ type TestEvent struct {
// AccountTestService handles account testing operations
type AccountTestService struct {
repos *repository.Repositories
oauthService *OAuthService
httpClient *http.Client
accountRepo ports.AccountRepository
oauthService *OAuthService
claudeUpstream ClaudeUpstream
}
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
return &AccountTestService{
repos: repos,
oauthService: oauthService,
httpClient: &http.Client{
Timeout: 60 * time.Second,
},
accountRepo: accountRepo,
oauthService: oauthService,
claudeUpstream: claudeUpstream,
}
}
@@ -62,10 +59,10 @@ func generateSessionString() string {
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
}
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts
func createTestPayload() map[string]interface{} {
// createTestPayload creates a Claude Code style test request payload
func createTestPayload(modelID string) map[string]interface{} {
return map[string]interface{}{
"model": testModel,
"model": modelID,
"messages": []map[string]interface{}{
{
"role": "user",
@@ -92,45 +89,48 @@ func createTestPayload() map[string]interface{} {
"metadata": map[string]string{
"user_id": generateSessionString(),
},
"max_tokens": 1024,
"max_tokens": 1024,
"temperature": 1,
"stream": true,
}
}
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
func createApiKeyTestPayload(model string) map[string]interface{} {
return map[string]interface{}{
"model": model,
"messages": []map[string]interface{}{
{
"role": "user",
"content": "hi",
},
},
"max_tokens": 1024,
"stream": true,
"stream": true,
}
}
// TestAccountConnection tests an account's connection by sending a test request
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error {
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
ctx := c.Request.Context()
// Get account
account, err := s.repos.Account.GetByID(ctx, accountID)
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
return s.sendErrorAndEnd(c, "Account not found")
}
// Determine authentication method based on account type
// Determine the model to use
testModelID := modelID
if testModelID == "" {
testModelID = claude.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if mapping != nil && len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
}
// Determine authentication method and API URL
var authToken string
var authType string // "bearer" for OAuth, "apikey" for API Key
var useBearer bool
var apiURL string
if account.IsOAuth() {
// OAuth or Setup Token account
authType = "bearer"
// OAuth or Setup Token - use Bearer token
useBearer = true
apiURL = testClaudeAPIURL
authToken = account.GetCredential("access_token")
if authToken == "" {
@@ -141,7 +141,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer
if err == nil && time.Now().Unix()+300 > expiresAt {
needRefresh = true
}
}
@@ -154,19 +154,17 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
authToken = tokenInfo.AccessToken
}
} else if account.Type == "apikey" {
// API Key account
authType = "apikey"
// API Key - use x-api-key header
useBearer = false
authToken = account.GetCredential("api_key")
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
// Get base URL (use default if not set)
apiURL = account.GetBaseURL()
if apiURL == "" {
apiURL = "https://api.anthropic.com"
}
// Append /v1/messages endpoint
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
@@ -179,57 +177,42 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create test request payload
var payload map[string]interface{}
var actualModel string
if authType == "apikey" {
// Use simpler payload for API Key (without Claude Code specific fields)
// Apply model mapping if configured
actualModel = account.GetMappedModel(testModel)
payload = createApiKeyTestPayload(actualModel)
} else {
actualModel = testModel
payload = createTestPayload()
}
// Create Claude Code style payload (same for all account types)
payload := createTestPayload(testModelID)
payloadBytes, _ := json.Marshal(payload)
// Send test_start event with model info
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel})
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
// Set headers based on auth type
// Set common headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
if authType == "bearer" {
// Apply Claude Code client headers
for key, value := range claude.DefaultHeaders {
req.Header.Set(key, value)
}
// Set authentication header
if useBearer {
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
} else {
// API Key uses x-api-key header
req.Header.Set("x-api-key", authToken)
}
// Configure proxy if account has one
transport := http.DefaultTransport.(*http.Transport).Clone()
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
proxyURL = account.Proxy.URL()
}
client := &http.Client{
Transport: transport,
Timeout: 60 * time.Second,
}
resp, err := client.Do(req)
resp, err := s.claudeUpstream.Do(req, proxyURL)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
@@ -252,7 +235,6 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
// Stream ended, send complete event
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
@@ -310,5 +292,5 @@ func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
log.Printf("Account test error: %s", errorMsg)
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
return fmt.Errorf(errorMsg)
return fmt.Errorf("%s", errorMsg)
}

View File

@@ -2,17 +2,13 @@ package service
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"sync"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
)
// usageCache 用于缓存usage数据
@@ -35,10 +31,10 @@ type WindowStats struct {
// UsageProgress 使用量进度
type UsageProgress struct {
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+100表示100%)
ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+100表示100%)
ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
}
// UsageInfo 账号使用量信息
@@ -65,21 +61,26 @@ type ClaudeUsageResponse struct {
} `json:"seven_day_sonnet"`
}
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type ClaudeUsageFetcher interface {
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
}
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
repos *repository.Repositories
accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository
oauthService *OAuthService
httpClient *http.Client
usageFetcher ClaudeUsageFetcher
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
return &AccountUsageService{
repos: repos,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
oauthService: oauthService,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
usageFetcher: usageFetcher,
}
}
@@ -88,7 +89,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS
// Setup Token账号: 根据session_window推算5h窗口7d数据不可用没有profile scope
// API Key账号: 不支持usage查询
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
account, err := s.repos.Account.GetByID(ctx, accountID)
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account failed: %w", err)
}
@@ -148,7 +149,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
return
@@ -163,7 +164,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
// GetTodayStats 获取账号今日统计
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get today stats failed: %w", err)
}
@@ -177,58 +178,23 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
// 获取access token从credentials中获取
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
}
// 获取代理配置
transport := http.DefaultTransport.(*http.Transport).Clone()
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
proxyURL = account.Proxy.URL()
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
// 构建请求
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
// 解析响应
var usageResp ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
// 转换为UsageInfo
now := time.Now()
return s.buildUsageInfo(&usageResp, &now), nil
return s.buildUsageInfo(usageResp, &now), nil
}
// parseTime 尝试多种格式解析时间

View File

@@ -2,20 +2,14 @@ package service
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"golang.org/x/net/proxy"
"gorm.io/gorm"
)
@@ -177,44 +171,63 @@ type ProxyTestResult struct {
Country string `json:"country,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ipinfo.io
type ProxyExitInfo struct {
IP string
City string
Region string
Country string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo *repository.UserRepository
groupRepo *repository.GroupRepository
accountRepo *repository.AccountRepository
proxyRepo *repository.ProxyRepository
apiKeyRepo *repository.ApiKeyRepository
redeemCodeRepo *repository.RedeemCodeRepository
usageLogRepo *repository.UsageLogRepository
userSubRepo *repository.UserSubscriptionRepository
userRepo ports.UserRepository
groupRepo ports.GroupRepository
accountRepo ports.AccountRepository
proxyRepo ports.ProxyRepository
apiKeyRepo ports.ApiKeyRepository
redeemCodeRepo ports.RedeemCodeRepository
usageLogRepo ports.UsageLogRepository
userSubRepo ports.UserSubscriptionRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
}
// NewAdminService creates a new AdminService
func NewAdminService(repos *repository.Repositories) AdminService {
func NewAdminService(
userRepo ports.UserRepository,
groupRepo ports.GroupRepository,
accountRepo ports.AccountRepository,
proxyRepo ports.ProxyRepository,
apiKeyRepo ports.ApiKeyRepository,
redeemCodeRepo ports.RedeemCodeRepository,
usageLogRepo ports.UsageLogRepository,
userSubRepo ports.UserSubscriptionRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
) AdminService {
return &adminServiceImpl{
userRepo: repos.User,
groupRepo: repos.Group,
accountRepo: repos.Account,
proxyRepo: repos.Proxy,
apiKeyRepo: repos.ApiKey,
redeemCodeRepo: repos.RedeemCode,
usageLogRepo: repos.UsageLog,
userSubRepo: repos.UserSubscription,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
// 注意AdminService是接口需要类型断言
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
if impl, ok := adminService.(*adminServiceImpl); ok {
impl.billingCacheService = billingCacheService
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
usageLogRepo: usageLogRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
}
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil {
return nil, 0, err
@@ -383,7 +396,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, 0, err
@@ -404,7 +417,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil {
return nil, 0, err
@@ -575,7 +588,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
return nil, 0, err
@@ -585,7 +598,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil {
return nil, 0, err
@@ -703,7 +716,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
@@ -788,7 +801,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
// Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
return nil, 0, err
@@ -872,79 +885,12 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
return nil, err
}
return testProxyConnection(ctx, proxy)
}
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
proxyURL := proxy.URL()
// Create HTTP client with proxy
transport, err := createProxyTransport(proxyURL)
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Failed to create proxy transport: %v", err),
}, nil
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
}
// Measure latency
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Failed to create request: %v", err),
}, nil
}
resp, err := client.Do(req)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Proxy connection failed: %v", err),
}, nil
}
defer resp.Body.Close()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
LatencyMs: latencyMs,
}, nil
}
// Parse ipinfo.io response
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to read response",
LatencyMs: latencyMs,
}, nil
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to parse response",
LatencyMs: latencyMs,
Message: err.Error(),
}, nil
}
@@ -952,38 +898,9 @@ func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestRes
Success: true,
Message: "Proxy is accessible",
LatencyMs: latencyMs,
IPAddress: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
IPAddress: exitInfo.IP,
City: exitInfo.City,
Region: exitInfo.Region,
Country: exitInfo.Country,
}, nil
}
// createProxyTransport creates an HTTP transport with the given proxy URL
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -8,8 +8,9 @@ import (
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
@@ -17,18 +18,16 @@ import (
)
var (
ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
)
const (
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
apiKeyMaxErrorsPerHour = 20
apiKeyRateLimitDuration = time.Hour
apiKeyMaxErrorsPerHour = 20
)
// CreateApiKeyRequest 创建API Key请求
@@ -47,21 +46,21 @@ type UpdateApiKeyRequest struct {
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo *repository.ApiKeyRepository
userRepo *repository.UserRepository
groupRepo *repository.GroupRepository
userSubRepo *repository.UserSubscriptionRepository
rdb *redis.Client
apiKeyRepo ports.ApiKeyRepository
userRepo ports.UserRepository
groupRepo ports.GroupRepository
userSubRepo ports.UserSubscriptionRepository
cache ports.ApiKeyCache
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo *repository.ApiKeyRepository,
userRepo *repository.UserRepository,
groupRepo *repository.GroupRepository,
userSubRepo *repository.UserSubscriptionRepository,
rdb *redis.Client,
apiKeyRepo ports.ApiKeyRepository,
userRepo ports.UserRepository,
groupRepo ports.GroupRepository,
userSubRepo ports.UserSubscriptionRepository,
cache ports.ApiKeyCache,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
@@ -69,7 +68,7 @@ func NewApiKeyService(
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
rdb: rdb,
cache: cache,
cfg: cfg,
}
}
@@ -112,13 +111,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
@@ -133,16 +130,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, _ = pipe.Exec(ctx)
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
}
// canUserBindGroup 检查用户是否可以绑定指定分组
@@ -237,7 +229,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -272,7 +264,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
}
// 缓存到Redis可选TTL设置为5分钟
if s.rdb != nil {
if s.cache != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
@@ -325,9 +317,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
if req.Status != nil {
apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
_ = s.rdb.Del(ctx, cacheKey)
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
}
}
@@ -354,9 +345,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// 清除Redis缓存
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
_ = s.rdb.Del(ctx, cacheKey)
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
}
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
@@ -399,13 +389,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.rdb != nil {
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
return fmt.Errorf("increment usage: %w", err)
}
// 设置24小时过期
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
}
return nil
}

View File

@@ -7,7 +7,7 @@ import (
"log"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"time"
"github.com/golang-jwt/jwt/v5"
@@ -16,13 +16,13 @@ import (
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
)
// JWTClaims JWT载荷数据
@@ -35,7 +35,7 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
userRepo *repository.UserRepository
userRepo ports.UserRepository
cfg *config.Config
settingService *SettingService
emailService *EmailService
@@ -44,33 +44,24 @@ type AuthService struct {
}
// NewAuthService 创建认证服务实例
func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService {
func NewAuthService(
userRepo ports.UserRepository,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
) *AuthService {
return &AuthService{
userRepo: userRepo,
cfg: cfg,
userRepo: userRepo,
cfg: cfg,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
}
}
// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证)
func (s *AuthService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// SetEmailService 设置邮件服务(用于邮件验证)
func (s *AuthService) SetEmailService(emailService *EmailService) {
s.emailService = emailService
}
// SetTurnstileService 设置Turnstile服务用于验证码校验
func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) {
s.turnstileService = turnstileService
}
// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件)
func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) {
s.emailQueueService = emailQueueService
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
return s.RegisterWithVerification(ctx, email, password, "")

View File

@@ -5,30 +5,10 @@ import (
"errors"
"fmt"
"log"
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// 缓存Key前缀和TTL
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
// 订阅缓存Hash字段
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
"sub2api/internal/service/ports"
)
// 错误定义
@@ -38,35 +18,6 @@ var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
)
// 预编译的Lua脚本
var (
// deductBalanceScript: 扣减余额缓存key不存在则忽略
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// updateSubUsageScript: 更新订阅用量缓存key不存在则忽略
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
type subscriptionCacheData struct {
Status string
@@ -80,15 +31,15 @@ type subscriptionCacheData struct {
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
rdb *redis.Client
userRepo *repository.UserRepository
subRepo *repository.UserSubscriptionRepository
cache ports.BillingCache
userRepo ports.UserRepository
subRepo ports.UserSubscriptionRepository
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService {
func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
return &BillingCacheService{
rdb: rdb,
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
}
@@ -100,24 +51,19 @@ func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserReposito
// GetUserBalance 获取用户余额(优先从缓存读取)
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
if s.rdb == nil {
if s.cache == nil {
// Redis不可用直接查询数据库
return s.getUserBalanceFromDB(ctx, userID)
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 尝试从缓存读取
val, err := s.rdb.Get(ctx, key).Result()
balance, err := s.cache.GetUserBalance(ctx, userID)
if err == nil {
balance, parseErr := strconv.ParseFloat(val, 64)
if parseErr == nil {
return balance, nil
}
return balance, nil
}
// 缓存未命中或解析错误,从数据库读取
balance, err := s.getUserBalanceFromDB(ctx, userID)
// 缓存未命中,从数据库读取
balance, err = s.getUserBalanceFromDB(ctx, userID)
if err != nil {
return 0, err
}
@@ -143,39 +89,28 @@ func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID i
// setBalanceCache 设置余额缓存
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
}
}
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 使用预编译的Lua脚本原子性扣减如果key不存在则忽略
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
return s.cache.DeductUserBalance(ctx, userID, amount)
}
// InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
return err
}
@@ -188,19 +123,14 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
if s.rdb == nil {
if s.cache == nil {
return s.getSubscriptionFromDB(ctx, userID, groupID)
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 尝试从缓存读取
result, err := s.rdb.HGetAll(ctx, key).Result()
if err == nil && len(result) > 0 {
data, parseErr := s.parseSubscriptionCache(result)
if parseErr == nil {
return data, nil
}
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
if err == nil && cacheData != nil {
return s.convertFromPortsData(cacheData), nil
}
// 缓存未命中,从数据库读取
@@ -219,6 +149,28 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
return data, nil
}
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData {
return &subscriptionCacheData{
Status: data.Status,
ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage,
WeeklyUsage: data.WeeklyUsage,
MonthlyUsage: data.MonthlyUsage,
Version: data.Version,
}
}
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData {
return &ports.SubscriptionCacheData{
Status: data.Status,
ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage,
WeeklyUsage: data.WeeklyUsage,
MonthlyUsage: data.MonthlyUsage,
Version: data.Version,
}
}
// getSubscriptionFromDB 从数据库获取订阅数据
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
@@ -236,90 +188,30 @@ func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID,
}, nil
}
// parseSubscriptionCache 解析订阅缓存数据
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
result := &subscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
// setSubscriptionCache 设置订阅缓存
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
if s.rdb == nil || data == nil {
if s.cache == nil || data == nil {
return
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := s.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
if _, err := pipe.Exec(ctx); err != nil {
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 使用预编译的Lua脚本原子性增加用量如果key不存在则忽略
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
}
// InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
return err
}

View File

@@ -2,22 +2,13 @@ package service
import (
"context"
"fmt"
"log"
"time"
"github.com/redis/go-redis/v9"
"sub2api/internal/service/ports"
)
const (
// Redis key prefixes
accountConcurrencyKey = "concurrency:account:"
userConcurrencyKey = "concurrency:user:"
userWaitCountKey = "concurrency:wait:"
// TTL for concurrency keys (auto-release safety net)
concurrencyKeyTTL = 10 * time.Minute
// Wait polling interval
waitPollInterval = 100 * time.Millisecond
@@ -28,70 +19,14 @@ const (
defaultExtraWaitSlots = 20
)
// Pre-compiled Lua scripts for better performance
var (
// acquireScript: increment counter if below max, return 1 if successful
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`)
// releaseScript: decrement counter, but don't go below 0
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// incrementWaitScript: increment wait counter if below max, return 1 if successful
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`)
// decrementWaitScript: decrement wait counter, but don't go below 0
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
// ConcurrencyService manages concurrent request limiting for accounts and users
type ConcurrencyService struct {
rdb *redis.Client
cache ports.ConcurrencyCache
}
// NewConcurrencyService creates a new ConcurrencyService
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
return &ConcurrencyService{rdb: rdb}
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
return &ConcurrencyService{cache: cache}
}
// AcquireResult represents the result of acquiring a concurrency slot
@@ -104,20 +39,6 @@ type AcquireResult struct {
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// acquireSlot is the core implementation for acquiring a concurrency slot
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
@@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
}, nil
}
// Try to acquire immediately
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
@@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
if acquired {
return &AcquireResult{
Acquired: true,
ReleaseFunc: s.makeReleaseFunc(key),
ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
}
},
}, nil
}
// Not acquired, return with Acquired=false
// The caller (gateway handler) will handle waiting with ping support
return &AcquireResult{
Acquired: false,
ReleaseFunc: nil,
}, nil
}
// tryAcquire attempts to increment the counter if below max
// Uses pre-compiled Lua script for atomicity and performance
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {}, // no-op
}, nil
}
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return false, fmt.Errorf("acquire slot failed: %w", err)
return nil, err
}
return result == 1, nil
}
// makeReleaseFunc creates a function to release a concurrency slot
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
return func() {
// Use background context to ensure release even if original context is cancelled
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
// Log error but don't panic - TTL will eventually clean up
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
}
if acquired {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
}
},
}, nil
}
}
// GetCurrentCount returns the current concurrency count for debugging/monitoring
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
val, err := s.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, err
}
return val, nil
}
// GetAccountCurrentCount returns current concurrency count for an account
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.GetCurrentCount(ctx, key)
}
// GetUserCurrentCount returns current concurrency count for a user
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.GetCurrentCount(ctx, key)
return &AcquireResult{
Acquired: false,
ReleaseFunc: nil,
}, nil
}
// ============================================
@@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int
// Returns true if successful, false if the wait queue is full.
// maxWait should be user.Concurrency + defaultExtraWaitSlots
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
if s.rdb == nil {
if s.cache == nil {
// Redis not available, allow request
return true, nil
}
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
if err != nil {
// On error, allow the request to proceed (fail open)
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
return true, nil
}
return result == 1, nil
return result, nil
}
// DecrementWaitCount decrements the wait queue counter for a user.
// Should be called when a request completes or exits the wait queue.
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
// Use background context to ensure decrement even if original context is cancelled
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
}
}
// GetUserWaitCount returns current wait queue count for a user
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
return s.GetCurrentCount(ctx, key)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {

View File

@@ -4,17 +4,14 @@ import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/smtp"
"strconv"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
)
var (
@@ -25,19 +22,11 @@ var (
)
const (
verifyCodeKeyPrefix = "email_verify:"
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
)
// verifyCodeData Redis 中存储的验证码数据
type verifyCodeData struct {
Code string `json:"code"`
Attempts int `json:"attempts"`
CreatedAt time.Time `json:"created_at"`
}
// SmtpConfig SMTP配置
type SmtpConfig struct {
Host string
@@ -51,15 +40,15 @@ type SmtpConfig struct {
// EmailService 邮件服务
type EmailService struct {
settingRepo *repository.SettingRepository
rdb *redis.Client
settingRepo ports.SettingRepository
cache ports.EmailCache
}
// NewEmailService 创建邮件服务实例
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService {
func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService {
return &EmailService{
settingRepo: settingRepo,
rdb: rdb,
cache: cache,
}
}
@@ -201,10 +190,8 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
// SendVerifyCode 发送验证码邮件
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
key := verifyCodeKeyPrefix + email
// 检查是否在冷却期内
existing, err := s.getVerifyCodeData(ctx, key)
existing, err := s.cache.GetVerificationCode(ctx, email)
if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent
@@ -218,12 +205,12 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
}
// 保存验证码到 Redis
data := &verifyCodeData{
data := &ports.VerificationCodeData{
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
}
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
}
@@ -241,9 +228,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
// VerifyCode 验证验证码
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
key := verifyCodeKeyPrefix + email
data, err := s.getVerifyCodeData(ctx, key)
data, err := s.cache.GetVerificationCode(ctx, email)
if err != nil || data == nil {
return ErrInvalidVerifyCode
}
@@ -256,7 +241,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配
if data.Code != code {
data.Attempts++
_ = s.setVerifyCodeData(ctx, key, data)
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
@@ -264,32 +249,10 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
}
// 验证成功,删除验证码
s.rdb.Del(ctx, key)
_ = s.cache.DeleteVerificationCode(ctx, email)
return nil
}
// getVerifyCodeData 从 Redis 获取验证码数据
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
val, err := s.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data verifyCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
// setVerifyCodeData 保存验证码数据到 Redis
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
val, err := json.Marshal(data)
if err != nil {
return err
}
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
}
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
return fmt.Sprintf(`

View File

@@ -12,25 +12,27 @@ import (
"io"
"log"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/claude"
"sub2api/internal/service/ports"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// ClaudeUpstream handles HTTP requests to Claude API
type ClaudeUpstream interface {
Do(req *http.Request, proxyURL string) (*http.Response, error)
}
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
stickySessionPrefix = "sticky_session:"
stickySessionTTL = time.Hour // 粘性会话TTL
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
)
// allowedHeaders 白名单headers参考CRS项目
@@ -76,46 +78,48 @@ type ForwardResult struct {
// GatewayService handles API gateway operations
type GatewayService struct {
repos *repository.Repositories
rdb *redis.Client
accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository
userRepo ports.UserRepository
userSubRepo ports.UserSubscriptionRepository
cache ports.GatewayCache
cfg *config.Config
oauthService *OAuthService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpClient *http.Client
claudeUpstream ClaudeUpstream
}
// NewGatewayService creates a new GatewayService
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService {
// 计算响应头超时时间
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second // 默认5分钟LLM高负载时可能排队较久
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout, // 等待上游响应头的超时
// 注意:不设置整体 Timeout让流式响应可以无限时间传输
}
func NewGatewayService(
accountRepo ports.AccountRepository,
usageLogRepo ports.UsageLogRepository,
userRepo ports.UserRepository,
userSubRepo ports.UserSubscriptionRepository,
cache ports.GatewayCache,
cfg *config.Config,
oauthService *OAuthService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
identityService *IdentityService,
claudeUpstream ClaudeUpstream,
) *GatewayService {
return &GatewayService{
repos: repos,
rdb: rdb,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
oauthService: oauthService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
httpClient: &http.Client{
Transport: transport,
// 不设置 Timeout流式请求可能持续十几分钟
// 超时控制由 Transport.ResponseHeaderTimeout 负责(只控制等待响应头)
},
claudeUpstream: claudeUpstream,
}
}
@@ -270,14 +274,14 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
account, err := s.repos.Account.GetByID(ctx, accountID)
account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 续期粘性会话
s.rdb.Expire(ctx, stickySessionPrefix+sessionHash, stickySessionTTL)
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return account, nil
}
}
@@ -287,9 +291,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
var accounts []model.Account
var err error
if groupID != nil {
accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID)
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else {
accounts, err = s.repos.Account.ListSchedulable(ctx)
accounts, err = s.accountRepo.ListSchedulable(ctx)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -327,7 +331,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// 4. 建立粘性绑定
if sessionHash != "" {
s.rdb.Set(ctx, stickySessionPrefix+sessionHash, selected.ID, stickySessionTTL)
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
}
return selected, nil
@@ -352,37 +356,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
accessToken := account.GetCredential("access_token")
expiresAtStr := account.GetCredential("expires_at")
// 检查是否需要刷新
needRefresh := false
if expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt {
needRefresh = true
}
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
if needRefresh || accessToken == "" {
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", "", fmt.Errorf("refresh token failed: %w", err)
}
// 更新账号凭证
account.Credentials["access_token"] = tokenInfo.AccessToken
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.repos.Account.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
return tokenInfo.AccessToken, "oauth", nil
}
// Token刷新由后台 TokenRefreshService 处理此处只返回当前token
return accessToken, "oauth", nil
}
@@ -418,48 +395,25 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}
// 构建上游请求
upstreamResult, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return nil, err
}
// 选择使用的client如果有代理则使用独立的client否则使用共享的httpClient
httpClient := s.httpClient
if upstreamResult.Client != nil {
httpClient = upstreamResult.Client
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := httpClient.Do(upstreamResult.Request)
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer resp.Body.Close()
// 处理401错误刷新token重试
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
resp.Body.Close()
token, tokenType, err = s.forceRefreshToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("token refresh failed: %w", err)
}
upstreamResult, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return nil, err
}
// 重试时也需要使用正确的client
httpClient = s.httpClient
if upstreamResult.Client != nil {
httpClient = upstreamResult.Client
}
resp, err = httpClient.Do(upstreamResult.Request)
if err != nil {
return nil, fmt.Errorf("retry request failed: %w", err)
}
defer resp.Body.Close()
}
// 处理错误响应
// 处理错误响应包括401由后台TokenRefreshService维护token有效性)
if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account)
}
@@ -491,13 +445,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}, nil
}
// buildUpstreamRequestResult contains the request and optional custom client for proxy
type buildUpstreamRequestResult struct {
Request *http.Request
Client *http.Client // nil means use default s.httpClient
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == model.AccountTypeApiKey {
@@ -506,7 +454,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// OAuth账号应用统一指纹
var fingerprint *Fingerprint
var fingerprint *ports.Fingerprint
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
@@ -566,48 +514,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
}
// 配置代理 - 创建独立的client避免并发修改共享httpClient
var customClient *http.Client
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
// 计算响应头超时时间(与默认 Transport 保持一致)
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
// 创建独立的client避免并发时修改共享的s.httpClient.Transport
customClient = &http.Client{
Transport: transport,
}
}
}
}
return &buildUpstreamRequestResult{
Request: req,
Client: customClient,
}, nil
return req, nil
}
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
const oauthBeta = "oauth-2025-04-20"
const claudeCodeBeta = "claude-code-20250219"
// 如果客户端传了anthropic-beta
if clientBetaHeader != "" {
// 已包含oauth beta则直接返回
if strings.Contains(clientBetaHeader, oauthBeta) {
if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
return clientBetaHeader
}
@@ -620,7 +536,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// 在claude-code-20250219后面插入oauth beta
claudeCodeIdx := -1
for i, p := range parts {
if p == claudeCodeBeta {
if p == claude.BetaClaudeCode {
claudeCodeIdx = i
break
}
@@ -630,13 +546,13 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// 在claude-code后面插入
newParts := make([]string, 0, len(parts)+1)
newParts = append(newParts, parts[:claudeCodeIdx+1]...)
newParts = append(newParts, oauthBeta)
newParts = append(newParts, claude.BetaOAuth)
newParts = append(newParts, parts[claudeCodeIdx+1:]...)
return strings.Join(newParts, ",")
}
// 没有claude-code放在第一位
return oauthBeta + "," + clientBetaHeader
return claude.BetaOAuth + "," + clientBetaHeader
}
// 客户端没传,根据模型生成
@@ -650,29 +566,10 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
// haiku模型不需要claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") {
return "oauth-2025-04-20,interleaved-thinking-2025-05-14"
return claude.HaikuBetaHeader
}
return "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
}
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", "", err
}
account.Credentials["access_token"] = tokenInfo.AccessToken
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.repos.Account.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
return tokenInfo.AccessToken, "oauth", nil
return claude.DefaultBetaHeader
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
@@ -1000,7 +897,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID
}
if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil {
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
log.Printf("Create usage log failed: %v", err)
}
@@ -1008,7 +905,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 {
if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
@@ -1023,7 +920,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 {
if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
@@ -1038,9 +935,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
// 更新账号最后使用时间
if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil {
if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil {
log.Printf("Update last used failed: %v", err)
}
return nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == model.AccountTypeApiKey {
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
body = s.replaceModelInBody(body, mappedModel)
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
}
}
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
return err
}
// 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
}
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
defer resp.Body.Close()
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
case 429:
errMsg = "Rate limit exceeded"
case 529:
errMsg = "Service overloaded"
}
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
return fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 透传成功响应
c.Data(resp.StatusCode, "application/json", respBody)
return nil
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == model.AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
// OAuth 账号:应用统一指纹和重写 userID
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("Authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传 headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
}
// 确保必要的 headers 存在
if req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
}
return req, nil
}
// countTokensError 返回 count_tokens 错误响应
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -5,7 +5,8 @@ import (
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -34,11 +35,11 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务
type GroupService struct {
groupRepo *repository.GroupRepository
groupRepo ports.GroupRepository
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService {
func NewGroupService(groupRepo ports.GroupRepository) *GroupService {
return &GroupService{
groupRepo: groupRepo,
}
@@ -84,7 +85,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
}
// List 获取分组列表
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) {
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err)

View File

@@ -11,15 +11,10 @@ import (
"net/http"
"regexp"
"strconv"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
)
const (
// Redis key prefix
identityFingerprintKey = "identity:fingerprint:"
)
// 预编译正则表达式(避免每次调用重新编译)
var (
@@ -29,20 +24,8 @@ var (
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
// Fingerprint 存储的指纹数据结构
type Fingerprint struct {
ClientID string `json:"client_id"` // 64位hex客户端ID首次随机生成
UserAgent string `json:"user_agent"` // User-Agent
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
}
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
var defaultFingerprint = ports.Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.52.0",
@@ -54,39 +37,31 @@ var defaultFingerprint = Fingerprint{
// IdentityService 管理OAuth账号的请求身份指纹
type IdentityService struct {
rdb *redis.Client
cache ports.IdentityCache
}
// NewIdentityService 创建新的IdentityService
func NewIdentityService(rdb *redis.Client) *IdentityService {
return &IdentityService{rdb: rdb}
func NewIdentityService(cache ports.IdentityCache) *IdentityService {
return &IdentityService{cache: cache}
}
// GetOrCreateFingerprint 获取或创建账号的指纹
// 如果缓存存在检测user-agent版本新版本则更新
// 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
// 尝试从Redis获取缓存的指纹
data, err := s.rdb.Get(ctx, key).Bytes()
if err == nil && len(data) > 0 {
// 缓存存在,解析指纹
var cached Fingerprint
if err := json.Unmarshal(data, &cached); err == nil {
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
// 更新user-agent
cached.UserAgent = clientUA
// 保存更新后的指纹
if newData, err := json.Marshal(cached); err == nil {
s.rdb.Set(ctx, key, newData, 0) // 永不过期
}
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return &cached, nil
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) {
// 尝试从缓存获取指纹
cached, err := s.cache.GetFingerprint(ctx, accountID)
if err == nil && cached != nil {
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
// 更新user-agent
cached.UserAgent = clientUA
// 保存更新后的指纹
_ = s.cache.SetFingerprint(ctx, accountID, cached)
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return cached, nil
}
// 缓存不存在或解析失败,创建新指纹
@@ -95,11 +70,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 生成随机ClientID
fp.ClientID = generateClientID()
// 保存到Redis(永不过期)
if data, err := json.Marshal(fp); err == nil {
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
// 保存到缓存(永不过期)
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
@@ -107,8 +80,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
}
// createFingerprintFromHeaders 从请求头创建指纹
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
fp := &Fingerprint{}
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint {
fp := &ports.Fingerprint{}
// 获取User-Agent
if ua := headers.Get("User-Agent"); ua != "" {
@@ -137,7 +110,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
}
// ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头)
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) {
if fp == nil {
return
}

View File

@@ -2,32 +2,36 @@ package service
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/oauth"
"sub2api/internal/repository"
"github.com/imroc/req/v3"
"sub2api/internal/service/ports"
)
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
type ClaudeOAuthClient interface {
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
}
// OAuthService handles OAuth authentication flows
type OAuthService struct {
sessionStore *oauth.SessionStore
proxyRepo *repository.ProxyRepository
proxyRepo ports.ProxyRepository
oauthClient ClaudeOAuthClient
}
// NewOAuthService creates a new OAuth service
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService {
func NewOAuthService(proxyRepo ports.ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
return &OAuthService{
sessionStore: oauth.NewSessionStore(),
proxyRepo: proxyRepo,
oauthClient: oauthClient,
}
}
@@ -210,177 +214,21 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.createReqClient(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
return s.oauthClient.GetOrganizationUUID(ctx, sessionKey, proxyURL)
}
// getAuthorizationCode gets the authorization code using sessionKey
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.createReqClient(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
// Build request body - must include organization_uuid as per CRS
reqBody := map[string]interface{}{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID, // Required field!
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
// Response contains redirect_uri with code, not direct code field
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
// Parse redirect_uri to extract code and state
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
// Combine code with state if present (as CRS does)
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
return fullCode, nil
return s.oauthClient.GetAuthorizationCode(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
}
// exchangeCodeForToken exchanges authorization code for tokens
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
client := s.createReqClient(proxyURL)
// Parse code#state format if present
authCode := code
codeState := ""
if parts := strings.Split(code, "#"); len(parts) > 1 {
authCode = parts[0]
codeState = parts[1]
}
// Build JSON body as CRS does (not form data!)
reqBody := map[string]interface{}{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
// Add state if present
if codeState != "" {
reqBody["state"] = codeState
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
return nil, err
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
tokenInfo := &TokenInfo{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
@@ -390,7 +238,6 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
Scope: tokenResp.Scope,
}
// Extract org_uuid and account_uuid from response
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
tokenInfo.OrgUUID = tokenResp.Organization.UUID
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
@@ -405,27 +252,9 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
// RefreshToken refreshes an OAuth token
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
client := s.createReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
return nil, err
}
return &TokenInfo{
@@ -455,17 +284,3 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
return s.RefreshToken(ctx, refreshToken, proxyURL)
}
// createReqClient creates a req client with Chrome impersonation and optional proxy
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
SetTimeout(60 * time.Second)
// Set proxy if specified
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -0,0 +1,35 @@
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error)
Update(ctx context.Context, account *model.Account) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
ListActive(ctx context.Context) ([]model.Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
}

View File

@@ -0,0 +1,24 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *model.ApiKey) error
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
Update(ctx context.Context, key *model.ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}

View File

@@ -0,0 +1,16 @@
package ports
import (
"context"
"time"
)
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}

View File

@@ -0,0 +1,31 @@
package ports
import (
"context"
"time"
)
// SubscriptionCacheData represents cached subscription data
type SubscriptionCacheData struct {
Status string
ExpiresAt time.Time
DailyUsage float64
WeeklyUsage float64
MonthlyUsage float64
Version int64
}
// BillingCache defines cache operations for billing service
type BillingCache interface {
// Balance operations
GetUserBalance(ctx context.Context, userID int64) (float64, error)
SetUserBalance(ctx context.Context, userID int64, balance float64) error
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
InvalidateUserBalance(ctx context.Context, userID int64) error
// Subscription operations
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
}

View File

@@ -0,0 +1,19 @@
package ports
import "context"
// ConcurrencyCache defines cache operations for concurrency service
type ConcurrencyCache interface {
// Slot management
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
}

View File

@@ -0,0 +1,20 @@
package ports
import (
"context"
"time"
)
// VerificationCodeData represents verification code data
type VerificationCodeData struct {
Code string
Attempts int
CreatedAt time.Time
}
// EmailCache defines cache operations for email service
type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
}

View File

@@ -0,0 +1,13 @@
package ports
import (
"context"
"time"
)
// GatewayCache defines cache operations for gateway service
type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}

View File

@@ -0,0 +1,28 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type GroupRepository interface {
Create(ctx context.Context, group *model.Group) error
GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
DB() *gorm.DB
}

View File

@@ -0,0 +1,21 @@
package ports
import "context"
// Fingerprint represents account fingerprint data
type Fingerprint struct {
ClientID string
UserAgent string
StainlessLang string
StainlessPackageVersion string
StainlessOS string
StainlessArch string
StainlessRuntime string
StainlessRuntimeVersion string
}
// IdentityCache defines cache operations for identity service
type IdentityCache interface {
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
}

View File

@@ -0,0 +1,23 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type ProxyRepository interface {
Create(ctx context.Context, proxy *model.Proxy) error
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
Update(ctx context.Context, proxy *model.Proxy) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
}

View File

@@ -0,0 +1,15 @@
package ports
import (
"context"
"time"
)
// RedeemCache defines cache operations for redeem service
type RedeemCache interface {
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementRedeemAttemptCount(ctx context.Context, userID int64) error
AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
ReleaseRedeemLock(ctx context.Context, code string) error
}

View File

@@ -0,0 +1,22 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type RedeemCodeRepository interface {
Create(ctx context.Context, code *model.RedeemCode) error
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
Update(ctx context.Context, code *model.RedeemCode) error
Delete(ctx context.Context, id int64) error
Use(ctx context.Context, id, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
}

View File

@@ -0,0 +1,17 @@
package ports
import (
"context"
"sub2api/internal/model"
)
type SettingRepository interface {
Get(ctx context.Context, key string) (*model.Setting, error)
GetValue(ctx context.Context, key string) (string, error)
Set(ctx context.Context, key, value string) error
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
SetMultiple(ctx context.Context, settings map[string]string) error
GetAll(ctx context.Context) (map[string]string, error)
Delete(ctx context.Context, key string) error
}

View File

@@ -0,0 +1,12 @@
package ports
import (
"context"
"time"
)
// UpdateCache defines cache operations for update service
type UpdateCache interface {
GetUpdateInfo(ctx context.Context) (string, error)
SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error
}

View File

@@ -0,0 +1,28 @@
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/usagestats"
)
type UsageLogRepository interface {
Create(ctx context.Context, log *model.UsageLog) error
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
}

View File

@@ -0,0 +1,26 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type UserRepository interface {
Create(ctx context.Context, user *model.User) error
GetByID(ctx context.Context, id int64) (*model.User, error)
GetByEmail(ctx context.Context, email string) (*model.User, error)
GetFirstAdmin(ctx context.Context) (*model.User, error)
Update(ctx context.Context, user *model.User) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
}

View File

@@ -0,0 +1,36 @@
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type UserSubscriptionRepository interface {
Create(ctx context.Context, sub *model.UserSubscription) error
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
Update(ctx context.Context, sub *model.UserSubscription) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
UpdateStatus(ctx context.Context, subscriptionID int64, status string) error
UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error
ActivateWindows(ctx context.Context, id int64, start time.Time) error
ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
IncrementUsage(ctx context.Context, id int64, costUSD float64) error
BatchUpdateExpiredStatus(ctx context.Context) (int64, error)
}

View File

@@ -1,13 +1,12 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
@@ -20,13 +19,19 @@ import (
// LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct {
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
}
// PricingRemoteClient 远程价格数据获取接口
type PricingRemoteClient interface {
FetchPricingJSON(ctx context.Context, url string) ([]byte, error)
FetchHashText(ctx context.Context, url string) (string, error)
}
// LiteLLMRawEntry 用于解析原始JSON数据
@@ -42,11 +47,12 @@ type LiteLLMRawEntry struct {
// PricingService 动态价格服务
type PricingService struct {
cfg *config.Config
mu sync.RWMutex
pricingData map[string]*LiteLLMModelPricing
lastUpdated time.Time
localHash string
cfg *config.Config
remoteClient PricingRemoteClient
mu sync.RWMutex
pricingData map[string]*LiteLLMModelPricing
lastUpdated time.Time
localHash string
// 停止信号
stopCh chan struct{}
@@ -54,11 +60,12 @@ type PricingService struct {
}
// NewPricingService 创建价格服务
func NewPricingService(cfg *config.Config) *PricingService {
func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *PricingService {
s := &PricingService{
cfg: cfg,
pricingData: make(map[string]*LiteLLMModelPricing),
stopCh: make(chan struct{}),
cfg: cfg,
remoteClient: remoteClient,
pricingData: make(map[string]*LiteLLMModelPricing),
stopCh: make(chan struct{}),
}
return s
}
@@ -199,21 +206,13 @@ func (s *PricingService) syncWithRemote() error {
func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Get(s.cfg.Pricing.RemoteURL)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response failed: %w", err)
}
// 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body)
@@ -367,29 +366,10 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
}
// computeFileHash 计算文件哈希
@@ -466,14 +446,14 @@ func (s *PricingService) extractBaseName(model string) string {
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// Claude模型系列匹配规则
familyPatterns := map[string][]string{
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
"opus-4": {"claude-opus-4", "claude-3-opus"},
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
"sonnet-3": {"claude-3-sonnet"},
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
"haiku-3": {"claude-3-haiku"},
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
"opus-4": {"claude-opus-4", "claude-3-opus"},
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
"sonnet-3": {"claude-3-sonnet"},
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
"haiku-3": {"claude-3-haiku"},
}
// 确定模型属于哪个系列

View File

@@ -5,7 +5,8 @@ import (
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -37,11 +38,11 @@ type UpdateProxyRequest struct {
// ProxyService 代理管理服务
type ProxyService struct {
proxyRepo *repository.ProxyRepository
proxyRepo ports.ProxyRepository
}
// NewProxyService 创建代理服务实例
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
func NewProxyService(proxyRepo ports.ProxyRepository) *ProxyService {
return &ProxyService{
proxyRepo: proxyRepo,
}
@@ -80,7 +81,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
}
// List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err)

View File

@@ -9,20 +9,20 @@ import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
)
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
repos *repository.Repositories
cfg *config.Config
accountRepo ports.AccountRepository
cfg *config.Config
}
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
func NewRateLimitService(accountRepo ports.AccountRepository, cfg *config.Config) *RateLimitService {
return &RateLimitService{
repos: repos,
cfg: cfg,
accountRepo: accountRepo,
cfg: cfg,
}
}
@@ -62,7 +62,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
return
}
@@ -77,7 +77,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
if resetTimestamp == "" {
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
@@ -88,7 +88,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
if err != nil {
log.Printf("Parse reset timestamp failed: %v", err)
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
@@ -97,7 +97,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
resetAt := time.Unix(ts, 0)
// 标记限流状态
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
return
}
@@ -105,7 +105,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
// 根据重置时间反推5h窗口
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
@@ -121,7 +121,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
return
}
@@ -152,13 +152,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
}
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
}
}
@@ -166,5 +166,5 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
// ClearRateLimit 清除账号的限流状态
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.repos.Account.ClearRateLimit(ctx, accountID)
return s.accountRepo.ClearRateLimit(ctx, accountID)
}

View File

@@ -8,7 +8,8 @@ import (
"fmt"
"strings"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
@@ -25,11 +26,9 @@ var (
)
const (
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemMaxErrorsPerHour = 20
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
redeemMaxErrorsPerHour = 20
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
)
// GenerateCodesRequest 生成兑换码请求
@@ -49,28 +48,30 @@ type RedeemCodeResponse struct {
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo *repository.RedeemCodeRepository
userRepo *repository.UserRepository
redeemRepo ports.RedeemCodeRepository
userRepo ports.UserRepository
subscriptionService *SubscriptionService
rdb *redis.Client
cache ports.RedeemCache
billingCacheService *BillingCacheService
}
// NewRedeemService 创建兑换码服务实例
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService {
func NewRedeemService(
redeemRepo ports.RedeemCodeRepository,
userRepo ports.UserRepository,
subscriptionService *SubscriptionService,
cache ports.RedeemCache,
billingCacheService *BillingCacheService,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
rdb: rdb,
cache: cache,
billingCacheService: billingCacheService,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// GenerateRandomCode 生成随机兑换码
func (s *RedeemService) GenerateRandomCode() (string, error) {
// 生成16字节随机数据
@@ -137,13 +138,11 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
if s.cache == nil {
return nil
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
@@ -158,27 +157,21 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
// incrementRedeemErrorCount 增加用户兑换错误计数
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, _ = pipe.Exec(ctx)
_ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
}
// acquireRedeemLock 尝试获取兑换码的分布式锁
// 返回 true 表示获取成功false 表示锁已被占用
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
if s.rdb == nil {
if s.cache == nil {
return true // 无 Redis 时降级为不加锁
}
key := redeemLockKeyPrefix + code
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
if err != nil {
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
return true
@@ -188,12 +181,11 @@ func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool
// releaseRedeemLock 释放兑换码的分布式锁
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
if s.rdb == nil {
if s.cache == nil {
return
}
key := redeemLockKeyPrefix + code
s.rdb.Del(ctx, key)
_ = s.cache.ReleaseRedeemLock(ctx, code)
}
// Redeem 使用兑换码
@@ -335,7 +327,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
}
// List 获取兑换码列表(管理员功能)
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
codes, pagination, err := s.redeemRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list redeem codes: %w", err)

View File

@@ -1,12 +1,5 @@
package service
import (
"sub2api/internal/config"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// Services 服务集合容器
type Services struct {
Auth *AuthService
@@ -33,107 +26,6 @@ type Services struct {
Subscription *SubscriptionService
Concurrency *ConcurrencyService
Identity *IdentityService
}
// NewServices 创建所有服务实例
func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services {
// 初始化价格服务
pricingService := NewPricingService(cfg)
if err := pricingService.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
// 初始化计费服务(依赖价格服务)
billingService := NewBillingService(cfg, pricingService)
// 初始化其他服务
authService := NewAuthService(repos.User, cfg)
userService := NewUserService(repos.User, cfg)
apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg)
groupService := NewGroupService(repos.Group)
accountService := NewAccountService(repos.Account, repos.Group)
proxyService := NewProxyService(repos.Proxy)
usageService := NewUsageService(repos.UsageLog, repos.User)
// 初始化订阅服务 (RedeemService 依赖)
subscriptionService := NewSubscriptionService(repos)
// 初始化兑换服务 (依赖订阅服务)
redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb)
// 初始化Admin服务
adminService := NewAdminService(repos)
// 初始化OAuth服务GatewayService依赖
oauthService := NewOAuthService(repos.Proxy)
// 初始化限流服务
rateLimitService := NewRateLimitService(repos, cfg)
// 初始化计费缓存服务
billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription)
// 初始化账号使用量服务
accountUsageService := NewAccountUsageService(repos, oauthService)
// 初始化账号测试服务
accountTestService := NewAccountTestService(repos, oauthService)
// 初始化身份指纹服务
identityService := NewIdentityService(rdb)
// 初始化Gateway服务
gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService)
// 初始化设置服务
settingService := NewSettingService(repos.Setting, cfg)
emailService := NewEmailService(repos.Setting, rdb)
// 初始化邮件队列服务
emailQueueService := NewEmailQueueService(emailService, 3)
// 初始化Turnstile服务
turnstileService := NewTurnstileService(settingService)
// 设置Auth服务的依赖用于注册开关和邮件验证
authService.SetSettingService(settingService)
authService.SetEmailService(emailService)
authService.SetTurnstileService(turnstileService)
authService.SetEmailQueueService(emailQueueService)
// 初始化并发控制服务
concurrencyService := NewConcurrencyService(rdb)
// 注入计费缓存服务到需要失效缓存的服务
redeemService.SetBillingCacheService(billingCacheService)
subscriptionService.SetBillingCacheService(billingCacheService)
SetAdminServiceBillingCache(adminService, billingCacheService)
return &Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oauthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
Update *UpdateService
TokenRefresh *TokenRefreshService
}

View File

@@ -2,12 +2,14 @@ package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strconv"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -18,12 +20,12 @@ var (
// SettingService 系统设置服务
type SettingService struct {
settingRepo *repository.SettingRepository
settingRepo ports.SettingRepository
cfg *config.Config
}
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
func NewSettingService(settingRepo ports.SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
settingRepo: settingRepo,
cfg: cfg,
@@ -262,3 +264,63 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
}
return value
}
// GenerateAdminApiKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := model.AdminApiKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, model.SettingKeyAdminApiKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
return key, nil
}
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", false, nil
}
return "", false, err
}
if key == "" {
return "", false, nil
}
// 脱敏:显示前 10 位和后 4 位
if len(key) > 14 {
maskedKey = key[:10] + "..." + key[len(key)-4:]
} else {
maskedKey = key
}
return maskedKey, true, nil
}
// GetAdminApiKey 获取完整的管理员 API Key仅供内部验证使用
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", nil // 未配置,返回空字符串
}
return "", err // 数据库错误
}
return key, nil
}
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, model.SettingKeyAdminApiKey)
}

View File

@@ -7,7 +7,8 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
)
var (
@@ -23,18 +24,18 @@ var (
// SubscriptionService 订阅服务
type SubscriptionService struct {
repos *repository.Repositories
groupRepo ports.GroupRepository
userSubRepo ports.UserSubscriptionRepository
billingCacheService *BillingCacheService
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService {
return &SubscriptionService{repos: repos}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
func NewSubscriptionService(groupRepo ports.GroupRepository, userSubRepo ports.UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
return &SubscriptionService{
groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
}
}
// AssignSubscriptionInput 分配订阅输入
@@ -49,7 +50,7 @@ type AssignSubscriptionInput struct {
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
@@ -58,7 +59,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
}
// 检查是否已存在订阅
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
return nil, err
}
@@ -88,10 +89,11 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
// 如果用户已有同分组的订阅:
// - 未过期:从当前过期时间累加天数
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
//
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
return nil, false, fmt.Errorf("group not found: %w", err)
}
@@ -100,7 +102,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 查询是否已有订阅
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
// 不存在记录是正常情况,其他错误需要返回
existingSub = nil
@@ -125,13 +127,13 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 更新过期时间
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
return nil, false, fmt.Errorf("extend subscription: %w", err)
}
// 如果订阅已过期或被暂停恢复为active状态
if existingSub.Status != model.SubscriptionStatusActive {
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
@@ -143,7 +145,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
newNotes += "\n"
}
newNotes += input.Notes
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
// 备注更新失败不影响主流程
}
}
@@ -159,7 +161,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 返回更新后的订阅
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID)
return sub, true, err // true 表示是续期
}
@@ -191,27 +193,27 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass
now := time.Now()
sub := &model.UserSubscription{
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
AssignedAt: now,
Notes: input.Notes,
CreatedAt: now,
UpdatedAt: now,
Notes: input.Notes,
CreatedAt: now,
UpdatedAt: now,
}
// 只有当 AssignedBy > 0 时才设置0 表示系统分配,如兑换码)
if input.AssignedBy > 0 {
sub.AssignedBy = &input.AssignedBy
}
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
if err := s.userSubRepo.Create(ctx, sub); err != nil {
return nil, err
}
// 重新获取完整订阅信息(包含关联)
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
return s.userSubRepo.GetByID(ctx, sub.ID)
}
// BulkAssignSubscriptionInput 批量分配订阅输入
@@ -225,17 +227,17 @@ type BulkAssignSubscriptionInput struct {
// BulkAssignResult 批量分配结果
type BulkAssignResult struct {
SuccessCount int
FailedCount int
SuccessCount int
FailedCount int
Subscriptions []model.UserSubscription
Errors []string
Errors []string
}
// BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{
Subscriptions: make([]model.UserSubscription, 0),
Errors: make([]string, 0),
Errors: make([]string, 0),
}
for _, userID := range input.UserIDs {
@@ -261,12 +263,12 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
// RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return err
}
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
if err := s.userSubRepo.Delete(ctx, subscriptionID); err != nil {
return err
}
@@ -285,20 +287,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
// ExtendSubscription 延长订阅
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
return nil, err
}
// 如果订阅已过期恢复为active状态
if sub.Status == model.SubscriptionStatusExpired {
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
return nil, err
}
}
@@ -313,17 +315,17 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}()
}
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
return s.userSubRepo.GetByID(ctx, subscriptionID)
}
// GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
return s.repos.UserSubscription.GetByID(ctx, id)
return s.userSubRepo.GetByID(ctx, id)
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
@@ -332,24 +334,29 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
// ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListByUserID(ctx, userID)
return s.userSubRepo.ListByUserID(ctx, userID)
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
return s.userSubRepo.ListActiveByUserID(ctx, userID)
}
// ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
return s.userSubRepo.ListByGroupID(ctx, groupID, params)
}
// List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
return s.userSubRepo.List(ctx, params, userID, groupID, status)
}
// startOfDay 返回给定时间所在日期的零点(保持原时区)
func startOfDay(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
}
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
@@ -358,39 +365,50 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
return nil
}
now := time.Now()
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
// 使用当天零点作为窗口起始时间
windowStart := startOfDay(time.Now())
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
now := time.Now()
// 使用当天零点作为新窗口起始时间
windowStart := startOfDay(time.Now())
needsInvalidateCache := false
// 日窗口重置24小时
if sub.NeedsDailyReset() {
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
return err
}
sub.DailyWindowStart = &now
sub.DailyWindowStart = &windowStart
sub.DailyUsageUSD = 0
needsInvalidateCache = true
}
// 周窗口重置7天
if sub.NeedsWeeklyReset() {
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
return err
}
sub.WeeklyWindowStart = &now
sub.WeeklyWindowStart = &windowStart
sub.WeeklyUsageUSD = 0
needsInvalidateCache = true
}
// 月窗口重置30天
if sub.NeedsMonthlyReset() {
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, windowStart); err != nil {
return err
}
sub.MonthlyWindowStart = &now
sub.MonthlyWindowStart = &windowStart
sub.MonthlyUsageUSD = 0
needsInvalidateCache = true
}
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
if needsInvalidateCache && s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
}
return nil
@@ -412,15 +430,15 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.U
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
}
// SubscriptionProgress 订阅进度
type SubscriptionProgress struct {
ID int64 `json:"id"`
GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"`
ID int64 `json:"id"`
GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"`
Daily *UsageWindowProgress `json:"daily,omitempty"`
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
@@ -428,25 +446,25 @@ type SubscriptionProgress struct {
// UsageWindowProgress 使用窗口进度
type UsageWindowProgress struct {
LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"`
LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"`
}
// GetSubscriptionProgress 获取订阅使用进度
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
group := sub.Group
if group == nil {
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
group, err = s.groupRepo.GetByID(ctx, sub.GroupID)
if err != nil {
return nil, err
}
@@ -464,12 +482,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.DailyLimitUSD
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
progress.Daily = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt,
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Daily.RemainingUSD < 0 {
@@ -488,12 +506,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.WeeklyLimitUSD
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
progress.Weekly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt,
LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Weekly.RemainingUSD < 0 {
@@ -512,12 +530,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
limit := *group.MonthlyLimitUSD
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
progress.Monthly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt,
LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Monthly.RemainingUSD < 0 {
@@ -536,7 +554,7 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
@@ -555,7 +573,7 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
}
// ValidateSubscription 验证订阅是否有效
@@ -568,7 +586,7 @@ func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *mod
}
if sub.IsExpired() {
// 更新状态
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
return ErrSubscriptionExpired
}
return nil

View File

@@ -0,0 +1,185 @@
package service
import (
"context"
"fmt"
"log"
"sync"
"time"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/service/ports"
)
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
type TokenRefreshService struct {
accountRepo ports.AccountRepository
refreshers []TokenRefresher
cfg *config.TokenRefreshConfig
stopCh chan struct{}
wg sync.WaitGroup
}
// NewTokenRefreshService 创建token刷新服务
func NewTokenRefreshService(
accountRepo ports.AccountRepository,
oauthService *OAuthService,
cfg *config.Config,
) *TokenRefreshService {
s := &TokenRefreshService{
accountRepo: accountRepo,
cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}),
}
// 注册平台特定的刷新器
s.refreshers = []TokenRefresher{
NewClaudeTokenRefresher(oauthService),
// 未来可以添加其他平台的刷新器:
// NewOpenAITokenRefresher(...),
// NewGeminiTokenRefresher(...),
}
return s
}
// Start 启动后台刷新服务
func (s *TokenRefreshService) Start() {
if !s.cfg.Enabled {
log.Println("[TokenRefresh] Service disabled by configuration")
return
}
s.wg.Add(1)
go s.refreshLoop()
log.Printf("[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)",
s.cfg.CheckIntervalMinutes, s.cfg.RefreshBeforeExpiryHours)
}
// Stop 停止刷新服务
func (s *TokenRefreshService) Stop() {
close(s.stopCh)
s.wg.Wait()
log.Println("[TokenRefresh] Service stopped")
}
// refreshLoop 刷新循环
func (s *TokenRefreshService) refreshLoop() {
defer s.wg.Done()
// 计算检查间隔
checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
if checkInterval < time.Minute {
checkInterval = 5 * time.Minute
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
// 启动时立即执行一次检查
s.processRefresh()
for {
select {
case <-ticker.C:
s.processRefresh()
case <-s.stopCh:
return
}
}
}
// processRefresh 执行一次刷新检查
func (s *TokenRefreshService) processRefresh() {
ctx := context.Background()
// 计算刷新窗口
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
// 获取所有active状态的账号
accounts, err := s.listActiveAccounts(ctx)
if err != nil {
log.Printf("[TokenRefresh] Failed to list accounts: %v", err)
return
}
refreshed, failed := 0, 0
for i := range accounts {
account := &accounts[i]
// 遍历所有刷新器,找到能处理此账号的
for _, refresher := range s.refreshers {
if !refresher.CanRefresh(account) {
continue
}
// 检查是否需要刷新
if !refresher.NeedsRefresh(account, refreshWindow) {
continue
}
// 执行刷新
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
log.Printf("[TokenRefresh] Account %d (%s) failed: %v", account.ID, account.Name, err)
failed++
} else {
log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", account.ID, account.Name)
refreshed++
}
// 每个账号只由一个refresher处理
break
}
}
if refreshed > 0 || failed > 0 {
log.Printf("[TokenRefresh] Cycle complete: %d refreshed, %d failed", refreshed, failed)
}
}
// listActiveAccounts 获取所有active状态的账号
// 使用ListActive确保刷新所有活跃账号的token包括临时禁用的
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
return s.accountRepo.ListActive(ctx)
}
// refreshWithRetry 带重试的刷新
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
var lastErr error
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
newCredentials, err := refresher.Refresh(ctx, account)
if err == nil {
// 刷新成功更新账号credentials
account.Credentials = model.JSONB(newCredentials)
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
return nil
}
lastErr = err
log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
account.ID, attempt, s.cfg.MaxRetries, err)
// 如果还有重试机会,等待后重试
if attempt < s.cfg.MaxRetries {
// 指数退避2^(attempt-1) * baseSeconds
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
time.Sleep(backoff)
}
}
// 所有重试都失败标记账号为error状态
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
}
return lastErr
}

View File

@@ -0,0 +1,90 @@
package service
import (
"context"
"strconv"
"time"
"sub2api/internal/model"
)
// TokenRefresher 定义平台特定的token刷新策略接口
// 通过此接口可以扩展支持不同平台Anthropic/OpenAI/Gemini
type TokenRefresher interface {
// CanRefresh 检查此刷新器是否能处理指定账号
CanRefresh(account *model.Account) bool
// NeedsRefresh 检查账号的token是否需要刷新
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
// Refresh 执行token刷新返回更新后的credentials
// 注意返回的map应该保留原有credentials中的所有字段只更新token相关字段
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
}
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
type ClaudeTokenRefresher struct {
oauthService *OAuthService
}
// NewClaudeTokenRefresher 创建Claude token刷新器
func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
return &ClaudeTokenRefresher{
oauthService: oauthService,
}
}
// CanRefresh 检查是否能处理此账号
// 只处理 anthropic 平台的 oauth 类型账号
// setup-token 虽然也是OAuth但有效期1年不需要频繁刷新
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
return account.Platform == model.PlatformAnthropic &&
account.Type == model.AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
return false
}
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err != nil {
return false
}
expiryTime := time.Unix(expiresAt, 0)
return time.Until(expiryTime) < refreshWindow
}
// Refresh 执行token刷新
// 保留原有credentials中的所有字段只更新token相关字段
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
}
// 保留现有credentials中的所有字段
newCredentials := make(map[string]interface{})
for k, v := range account.Credentials {
newCredentials[k] = v
}
// 只更新token相关字段
// 注意expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
return newCredentials, nil
}

View File

@@ -2,14 +2,9 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
)
var (
@@ -19,10 +14,15 @@ var (
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
// TurnstileVerifier 验证 Turnstile token 的接口
type TurnstileVerifier interface {
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
}
// TurnstileService Turnstile 验证服务
type TurnstileService struct {
settingService *SettingService
httpClient *http.Client
verifier TurnstileVerifier
}
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
@@ -36,12 +36,10 @@ type TurnstileVerifyResponse struct {
}
// NewTurnstileService 创建 Turnstile 服务实例
func NewTurnstileService(settingService *SettingService) *TurnstileService {
func NewTurnstileService(settingService *SettingService, verifier TurnstileVerifier) *TurnstileService {
return &TurnstileService{
settingService: settingService,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
verifier: verifier,
}
}
@@ -66,35 +64,12 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote
return ErrTurnstileVerificationFailed
}
// 构建请求
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// 发送请求
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
resp, err := s.httpClient.Do(req)
result, err := s.verifier.VerifyToken(ctx, secretKey, token, remoteIP)
if err != nil {
log.Printf("[Turnstile] Request failed: %v", err)
return fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
// 解析响应
var result TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
log.Printf("[Turnstile] Failed to decode response: %v", err)
return fmt.Errorf("decode response: %w", err)
}
if !result.Success {
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)

Some files were not shown because too many files have changed in this diff Show More