mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
adebd941e1 | ||
|
|
bb500b7b2a | ||
|
|
cceada7dae | ||
|
|
5c2e7ae265 | ||
|
|
420bedd615 | ||
|
|
a79f6c5e1e | ||
|
|
0484c59ead | ||
|
|
7bbf621490 | ||
|
|
ef81aeb463 | ||
|
|
22414326cc | ||
|
|
14b155c66b | ||
|
|
e99b344b2b | ||
|
|
7fd94ab78b | ||
|
|
078529e51e | ||
|
|
23a4cf11c8 | ||
|
|
d1f0902ec0 | ||
|
|
ee86dbca9d | ||
|
|
733d4c2b85 | ||
|
|
406d3f3cab | ||
|
|
1ed93a5fd0 | ||
|
|
463ddea36f | ||
|
|
e769f67699 | ||
|
|
52d2ae9708 | ||
|
|
2e59998c51 | ||
|
|
32e58115cc | ||
|
|
ba27026399 | ||
|
|
c15b419c4c | ||
|
|
5bd27a5d17 | ||
|
|
0e7b8aab8c | ||
|
|
236908c03d | ||
|
|
67d028cf50 | ||
|
|
66ba487697 |
23
README.md
23
README.md
@@ -16,6 +16,14 @@ English | [中文](README_CN.md)
|
||||
|
||||
---
|
||||
|
||||
## Demo
|
||||
|
||||
Try Sub2API online: **https://v2.pincc.ai/**
|
||||
|
||||
| Email | Password |
|
||||
|-------|----------|
|
||||
| admin@sub2api.com | admin123 |
|
||||
|
||||
## Overview
|
||||
|
||||
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
|
||||
@@ -208,20 +216,19 @@ Build and run from source code for development or customization.
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 2. Build backend
|
||||
cd backend
|
||||
go build -o sub2api ./cmd/server
|
||||
|
||||
# 3. Build frontend
|
||||
cd ../frontend
|
||||
# 2. Build frontend
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
|
||||
# 4. Copy frontend build to backend (for embedding)
|
||||
# 3. Copy frontend build to backend (for embedding)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 5. Create configuration file
|
||||
# 4. Build backend (requires frontend dist to be present)
|
||||
cd ../backend
|
||||
go build -o sub2api ./cmd/server
|
||||
|
||||
# 5. Create configuration file
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. Edit configuration
|
||||
|
||||
23
README_CN.md
23
README_CN.md
@@ -16,6 +16,14 @@
|
||||
|
||||
---
|
||||
|
||||
## 在线体验
|
||||
|
||||
体验地址:**https://v2.pincc.ai/**
|
||||
|
||||
| 邮箱 | 密码 |
|
||||
|------|------|
|
||||
| admin@sub2api.com | admin123 |
|
||||
|
||||
## 项目概述
|
||||
|
||||
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
|
||||
@@ -208,20 +216,19 @@ docker-compose logs -f
|
||||
git clone https://github.com/Wei-Shaw/sub2api.git
|
||||
cd sub2api
|
||||
|
||||
# 2. 编译后端
|
||||
cd backend
|
||||
go build -o sub2api ./cmd/server
|
||||
|
||||
# 3. 编译前端
|
||||
cd ../frontend
|
||||
# 2. 编译前端
|
||||
cd frontend
|
||||
npm install
|
||||
npm run build
|
||||
|
||||
# 4. 复制前端构建产物到后端(用于嵌入)
|
||||
# 3. 复制前端构建产物到后端(用于嵌入)
|
||||
cp -r dist ../backend/internal/web/
|
||||
|
||||
# 5. 创建配置文件
|
||||
# 4. 编译后端(需要前端 dist 目录存在)
|
||||
cd ../backend
|
||||
go build -o sub2api ./cmd/server
|
||||
|
||||
# 5. 创建配置文件
|
||||
cp ../deploy/config.example.yaml ./config.yaml
|
||||
|
||||
# 6. 编辑配置
|
||||
|
||||
@@ -18,18 +18,10 @@ import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"sub2api/internal/setup"
|
||||
"sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
//go:embed VERSION
|
||||
@@ -103,8 +95,10 @@ func runSetupServer() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
}
|
||||
|
||||
addr := ":8080"
|
||||
log.Printf("Setup wizard available at http://localhost%s", addr)
|
||||
// Get server address from config.yaml or environment variables (SERVER_HOST, SERVER_PORT)
|
||||
// This allows users to run setup on a different address if needed
|
||||
addr := config.GetServerAddress()
|
||||
log.Printf("Setup wizard available at http://%s", addr)
|
||||
log.Println("Complete the setup wizard to configure Sub2API")
|
||||
|
||||
if err := r.Run(addr); err != nil {
|
||||
@@ -149,319 +143,3 @@ func runMainServer() {
|
||||
|
||||
log.Println("Server exited")
|
||||
}
|
||||
|
||||
func initDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
// 初始化时区(在数据库连接之前,确保时区设置正确)
|
||||
if err := timezone.Init(cfg.Timezone); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gormConfig := &gorm.Config{}
|
||||
if cfg.Server.Mode == "debug" {
|
||||
gormConfig.Logger = logger.Default.LogMode(logger.Info)
|
||||
}
|
||||
|
||||
// 使用带时区的 DSN 连接数据库
|
||||
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 自动迁移(始终执行,确保数据库结构与代码同步)
|
||||
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
|
||||
if err := model.AutoMigrate(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initRedis(cfg *config.Config) *redis.Client {
|
||||
return redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
})
|
||||
}
|
||||
|
||||
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Setup status endpoint (always returns needs_setup: false in normal mode)
|
||||
// This is used by the frontend to detect when the service has restarted after setup
|
||||
r.GET("/setup/status", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": gin.H{
|
||||
"needs_setup": false,
|
||||
"step": "completed",
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// API v1
|
||||
v1 := r.Group("/api/v1")
|
||||
{
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
settings := v1.Group("/settings")
|
||||
{
|
||||
settings.GET("/public", h.Setting.GetPublicSettings)
|
||||
}
|
||||
|
||||
// 需要认证的接口
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
|
||||
{
|
||||
// 当前用户信息
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
{
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
keys := authenticated.Group("/keys")
|
||||
{
|
||||
keys.GET("", h.APIKey.List)
|
||||
keys.GET("/:id", h.APIKey.GetByID)
|
||||
keys.POST("", h.APIKey.Create)
|
||||
keys.PUT("/:id", h.APIKey.Update)
|
||||
keys.DELETE("/:id", h.APIKey.Delete)
|
||||
}
|
||||
|
||||
// 用户可用分组(非管理员接口)
|
||||
groups := authenticated.Group("/groups")
|
||||
{
|
||||
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
||||
}
|
||||
|
||||
// 使用记录
|
||||
usage := authenticated.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Usage.List)
|
||||
usage.GET("/:id", h.Usage.GetByID)
|
||||
usage.GET("/stats", h.Usage.Stats)
|
||||
// User dashboard endpoints
|
||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
redeem := authenticated.Group("/redeem")
|
||||
{
|
||||
redeem.POST("", h.Redeem.Redeem)
|
||||
redeem.GET("/history", h.Redeem.GetHistory)
|
||||
}
|
||||
|
||||
// 用户订阅
|
||||
subscriptions := authenticated.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Subscription.List)
|
||||
subscriptions.GET("/active", h.Subscription.GetActive)
|
||||
subscriptions.GET("/progress", h.Subscription.GetProgress)
|
||||
subscriptions.GET("/summary", h.Subscription.GetSummary)
|
||||
}
|
||||
}
|
||||
|
||||
// 管理员接口
|
||||
admin := v1.Group("/admin")
|
||||
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
|
||||
{
|
||||
// 仪表盘
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
||||
}
|
||||
|
||||
// 用户管理
|
||||
users := admin.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.User.List)
|
||||
users.GET("/:id", h.Admin.User.GetByID)
|
||||
users.POST("", h.Admin.User.Create)
|
||||
users.PUT("/:id", h.Admin.User.Update)
|
||||
users.DELETE("/:id", h.Admin.User.Delete)
|
||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||
}
|
||||
|
||||
// 分组管理
|
||||
groups := admin.Group("/groups")
|
||||
{
|
||||
groups.GET("", h.Admin.Group.List)
|
||||
groups.GET("/all", h.Admin.Group.GetAll)
|
||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||
groups.POST("", h.Admin.Group.Create)
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
|
||||
// 账号管理
|
||||
accounts := admin.Group("/accounts")
|
||||
{
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
|
||||
// OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
||||
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
||||
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
|
||||
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
|
||||
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
||||
}
|
||||
|
||||
// 代理管理
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
proxies.GET("", h.Admin.Proxy.List)
|
||||
proxies.GET("/all", h.Admin.Proxy.GetAll)
|
||||
proxies.GET("/:id", h.Admin.Proxy.GetByID)
|
||||
proxies.POST("", h.Admin.Proxy.Create)
|
||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
||||
}
|
||||
|
||||
// 卡密管理
|
||||
codes := admin.Group("/redeem-codes")
|
||||
{
|
||||
codes.GET("", h.Admin.Redeem.List)
|
||||
codes.GET("/stats", h.Admin.Redeem.GetStats)
|
||||
codes.GET("/export", h.Admin.Redeem.Export)
|
||||
codes.GET("/:id", h.Admin.Redeem.GetByID)
|
||||
codes.POST("/generate", h.Admin.Redeem.Generate)
|
||||
codes.DELETE("/:id", h.Admin.Redeem.Delete)
|
||||
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
|
||||
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
|
||||
}
|
||||
|
||||
// 系统设置
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
}
|
||||
|
||||
// 系统管理
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
system.GET("/version", h.Admin.System.GetVersion)
|
||||
system.GET("/check-updates", h.Admin.System.CheckUpdates)
|
||||
system.POST("/update", h.Admin.System.PerformUpdate)
|
||||
system.POST("/rollback", h.Admin.System.Rollback)
|
||||
system.POST("/restart", h.Admin.System.RestartService)
|
||||
}
|
||||
|
||||
// 订阅管理
|
||||
subscriptions := admin.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Admin.Subscription.List)
|
||||
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
|
||||
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
|
||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||
}
|
||||
|
||||
// 分组下的订阅列表
|
||||
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
|
||||
|
||||
// 用户下的订阅列表
|
||||
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
|
||||
|
||||
// 使用记录管理
|
||||
usage := admin.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Admin.Usage.List)
|
||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
// setupRouter 配置路由器中间件和路由
|
||||
func setupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware.Logger())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, services, repos)
|
||||
|
||||
// Serve embedded frontend if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// createHTTPServer 创建HTTP服务器
|
||||
func createHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
||||
return &http.Server{
|
||||
Addr: cfg.Server.Address(),
|
||||
Handler: router,
|
||||
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
|
||||
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
|
||||
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
|
||||
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
|
||||
// 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟
|
||||
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,16 @@ package main
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/infrastructure"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/server"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
@@ -24,80 +28,90 @@ type Application struct {
|
||||
|
||||
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
wire.Build(
|
||||
// Config provider
|
||||
provideConfig,
|
||||
// 基础设施层 ProviderSets
|
||||
config.ProviderSet,
|
||||
infrastructure.ProviderSet,
|
||||
|
||||
// Database provider
|
||||
provideDB,
|
||||
// 业务层 ProviderSets
|
||||
repository.ProviderSet,
|
||||
service.ProviderSet,
|
||||
handler.ProviderSet,
|
||||
|
||||
// Redis provider
|
||||
provideRedis,
|
||||
// 服务器层 ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// Repository provider
|
||||
provideRepositories,
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
// Service provider
|
||||
provideServices,
|
||||
|
||||
// Handler provider
|
||||
provideHandlers,
|
||||
|
||||
// Router provider
|
||||
provideRouter,
|
||||
|
||||
// HTTP Server provider
|
||||
provideHTTPServer,
|
||||
|
||||
// Cleanup provider
|
||||
// 清理函数提供者
|
||||
provideCleanup,
|
||||
|
||||
// Application provider
|
||||
// 应用程序结构体
|
||||
wire.Struct(new(Application), "Server", "Cleanup"),
|
||||
)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func provideConfig() (*config.Config, error) {
|
||||
return config.Load()
|
||||
}
|
||||
|
||||
func provideDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
return initDB(cfg)
|
||||
}
|
||||
|
||||
func provideRedis(cfg *config.Config) *redis.Client {
|
||||
return initRedis(cfg)
|
||||
}
|
||||
|
||||
func provideRepositories(db *gorm.DB) *repository.Repositories {
|
||||
return repository.NewRepositories(db)
|
||||
}
|
||||
|
||||
func provideServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *service.Services {
|
||||
return service.NewServices(repos, rdb, cfg)
|
||||
}
|
||||
|
||||
func provideHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo handler.BuildInfo) *handler.Handlers {
|
||||
return handler.NewHandlers(services, repos, rdb, buildInfo)
|
||||
}
|
||||
|
||||
func provideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
BuildType: buildInfo.BuildType,
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
|
||||
return setupRouter(r, cfg, handlers, services, repos)
|
||||
}
|
||||
|
||||
func provideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
||||
return createHTTPServer(cfg, router)
|
||||
}
|
||||
|
||||
func provideCleanup() func() {
|
||||
func provideCleanup(
|
||||
db *gorm.DB,
|
||||
rdb *redis.Client,
|
||||
services *service.Services,
|
||||
) func() {
|
||||
return func() {
|
||||
// @todo
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Cleanup steps in reverse dependency order
|
||||
cleanupSteps := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"EmailQueueService", func() error {
|
||||
services.EmailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
{"Database", func() error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}},
|
||||
}
|
||||
|
||||
for _, step := range cleanupSteps {
|
||||
if err := step.fn(); err != nil {
|
||||
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||
// Continue with remaining cleanup steps even if one fails
|
||||
} else {
|
||||
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if context timed out
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
|
||||
default:
|
||||
log.Printf("[Cleanup] All cleanup steps completed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,14 +7,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"context"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"net/http"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/handler/admin"
|
||||
"sub2api/internal/infrastructure"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/server"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
@@ -24,23 +29,134 @@ import (
|
||||
// Injectors from wire.go:
|
||||
|
||||
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
config, err := provideConfig()
|
||||
configConfig, err := config.ProvideConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db, err := provideDB(config)
|
||||
db, err := infrastructure.ProvideDB(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
repositories := provideRepositories(db)
|
||||
client := provideRedis(config)
|
||||
services := provideServices(repositories, client, config)
|
||||
handlers := provideHandlers(services, repositories, client, buildInfo)
|
||||
engine := provideRouter(config, handlers, services, repositories)
|
||||
server := provideHTTPServer(config, engine)
|
||||
v := provideCleanup()
|
||||
userRepository := repository.NewUserRepository(db)
|
||||
settingRepository := repository.NewSettingRepository(db)
|
||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||
client := infrastructure.ProvideRedis(configConfig)
|
||||
emailCache := repository.NewEmailCache(client)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(authService)
|
||||
userService := service.NewUserService(userRepository, configConfig)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||
groupRepository := repository.NewGroupRepository(db)
|
||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
|
||||
apiKeyCache := repository.NewApiKeyCache(client)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
||||
billingCache := repository.NewBillingCache(client)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||
redeemCache := repository.NewRedeemCache(client)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
accountRepository := repository.NewAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
|
||||
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
|
||||
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||
updateCache := repository.NewUpdateCache(client)
|
||||
gitHubReleaseClient := repository.NewGitHubReleaseClient()
|
||||
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
|
||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
gatewayCache := repository.NewGatewayCache(client)
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityCache := repository.NewIdentityCache(client)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||
groupService := service.NewGroupService(groupRepository)
|
||||
accountService := service.NewAccountService(accountRepository, groupRepository)
|
||||
proxyService := service.NewProxyService(proxyRepository)
|
||||
services := &service.Services{
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OAuth: oAuthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
}
|
||||
repositories := &repository.Repositories{
|
||||
User: userRepository,
|
||||
ApiKey: apiKeyRepository,
|
||||
Group: groupRepository,
|
||||
Account: accountRepository,
|
||||
Proxy: proxyRepository,
|
||||
RedeemCode: redeemCodeRepository,
|
||||
UsageLog: usageLogRepository,
|
||||
Setting: settingRepository,
|
||||
UserSubscription: userSubscriptionRepository,
|
||||
}
|
||||
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
v := provideCleanup(db, client, services)
|
||||
application := &Application{
|
||||
Server: server,
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
}
|
||||
return application, nil
|
||||
@@ -53,47 +169,64 @@ type Application struct {
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func provideConfig() (*config.Config, error) {
|
||||
return config.Load()
|
||||
}
|
||||
|
||||
func provideDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
return initDB(cfg)
|
||||
}
|
||||
|
||||
func provideRedis(cfg *config.Config) *redis.Client {
|
||||
return initRedis(cfg)
|
||||
}
|
||||
|
||||
func provideRepositories(db *gorm.DB) *repository.Repositories {
|
||||
return repository.NewRepositories(db)
|
||||
}
|
||||
|
||||
func provideServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *service.Services {
|
||||
return service.NewServices(repos, rdb, cfg)
|
||||
}
|
||||
|
||||
func provideHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo handler.BuildInfo) *handler.Handlers {
|
||||
return handler.NewHandlers(services, repos, rdb, buildInfo)
|
||||
}
|
||||
|
||||
func provideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
BuildType: buildInfo.BuildType,
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
|
||||
return setupRouter(r, cfg, handlers, services, repos)
|
||||
}
|
||||
|
||||
func provideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
||||
return createHTTPServer(cfg, router)
|
||||
}
|
||||
|
||||
func provideCleanup() func() {
|
||||
func provideCleanup(
|
||||
db *gorm.DB,
|
||||
rdb *redis.Client,
|
||||
services *service.Services,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cleanupSteps := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"EmailQueueService", func() error {
|
||||
services.EmailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
{"Database", func() error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}},
|
||||
}
|
||||
|
||||
for _, step := range cleanupSteps {
|
||||
if err := step.fn(); err != nil {
|
||||
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||
|
||||
} else {
|
||||
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
|
||||
default:
|
||||
log.Printf("[Cleanup] All cleanup steps completed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,15 +8,30 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
}
|
||||
|
||||
// TokenRefreshConfig OAuth token自动刷新配置
|
||||
type TokenRefreshConfig struct {
|
||||
// 是否启用自动刷新
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// 检查间隔(分钟)
|
||||
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
|
||||
// 提前刷新时间(小时),在token过期前多久开始刷新
|
||||
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
|
||||
// 最大重试次数
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
// 重试退避基础时间(秒)
|
||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
@@ -192,6 +207,13 @@ func setDefaults() {
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
@@ -203,3 +225,29 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServerAddress returns the server address (host:port) from config file or environment variable.
|
||||
// This is a lightweight function that can be used before full config validation,
|
||||
// such as during setup wizard startup.
|
||||
// Priority: config.yaml > environment variables > defaults
|
||||
func GetServerAddress() string {
|
||||
v := viper.New()
|
||||
v.SetConfigName("config")
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath("./config")
|
||||
v.AddConfigPath("/etc/sub2api")
|
||||
|
||||
// Support SERVER_HOST and SERVER_PORT environment variables
|
||||
v.AutomaticEnv()
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
v.SetDefault("server.host", "0.0.0.0")
|
||||
v.SetDefault("server.port", 8080)
|
||||
|
||||
// Try to read config file (ignore errors if not found)
|
||||
_ = v.ReadInConfig()
|
||||
|
||||
host := v.GetString("server.host")
|
||||
port := v.GetInt("server.port")
|
||||
return fmt.Sprintf("%s:%d", host, port)
|
||||
}
|
||||
|
||||
13
backend/internal/config/wire.go
Normal file
13
backend/internal/config/wire.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package config
|
||||
|
||||
import "github.com/google/wire"
|
||||
|
||||
// ProviderSet 提供配置层的依赖
|
||||
var ProviderSet = wire.NewSet(
|
||||
ProvideConfig,
|
||||
)
|
||||
|
||||
// ProvideConfig 提供应用配置
|
||||
func ProvideConfig() (*Config, error) {
|
||||
return Load()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
|
||||
@@ -186,6 +187,11 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Account deleted successfully"})
|
||||
}
|
||||
|
||||
// TestAccountRequest represents the request body for testing an account
|
||||
type TestAccountRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
|
||||
// Test handles testing account connectivity with SSE streaming
|
||||
// POST /api/v1/admin/accounts/:id/test
|
||||
func (h *AccountHandler) Test(c *gin.Context) {
|
||||
@@ -195,8 +201,12 @@ func (h *AccountHandler) Test(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var req TestAccountRequest
|
||||
// Allow empty body, model_id is optional
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
// Use AccountTestService to test the account with SSE streaming
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil {
|
||||
if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil {
|
||||
// Error already sent via SSE, just log
|
||||
return
|
||||
}
|
||||
@@ -231,16 +241,20 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Update account credentials
|
||||
newCredentials := map[string]interface{}{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"token_type": tokenInfo.TokenType,
|
||||
"expires_in": tokenInfo.ExpiresIn,
|
||||
"expires_at": tokenInfo.ExpiresAt,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"scope": tokenInfo.Scope,
|
||||
// Copy existing credentials to preserve non-token settings (e.g., intercept_warmup_requests)
|
||||
newCredentials := make(map[string]interface{})
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
// Update token-related fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
@@ -535,3 +549,58 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
||||
|
||||
response.Success(c, account)
|
||||
}
|
||||
|
||||
// GetAvailableModels handles getting available models for an account
|
||||
// GET /api/v1/admin/accounts/:id/models
|
||||
func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
response.Success(c, claude.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: return models based on model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping == nil || len(mapping) == 0 {
|
||||
// No mapping configured, return default models
|
||||
response.Success(c, claude.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// Return mapped models (keys of the mapping are the available model IDs)
|
||||
var models []claude.Model
|
||||
for requestedModel := range mapping {
|
||||
// Try to find display info from default models
|
||||
var found bool
|
||||
for _, dm := range claude.DefaultModels {
|
||||
if dm.ID == requestedModel {
|
||||
models = append(models, dm)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// If not found in defaults, create a basic entry
|
||||
if !found {
|
||||
models = append(models, claude.Model{
|
||||
ID: requestedModel,
|
||||
Type: "model",
|
||||
DisplayName: requestedModel,
|
||||
CreatedAt: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
@@ -127,12 +127,25 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour)
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity)
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -148,11 +161,24 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD)
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime)
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/service"
|
||||
@@ -112,12 +113,12 @@ func (h *ProxyHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
Name: req.Name,
|
||||
Protocol: req.Protocol,
|
||||
Host: req.Host,
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Protocol: strings.TrimSpace(req.Protocol),
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Username: strings.TrimSpace(req.Username),
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create proxy: "+err.Error())
|
||||
@@ -143,13 +144,13 @@ func (h *ProxyHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
|
||||
Name: req.Name,
|
||||
Protocol: req.Protocol,
|
||||
Host: req.Host,
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Protocol: strings.TrimSpace(req.Protocol),
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Status: req.Status,
|
||||
Username: strings.TrimSpace(req.Username),
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
Status: strings.TrimSpace(req.Status),
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update proxy: "+err.Error())
|
||||
@@ -263,8 +264,14 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
|
||||
skipped := 0
|
||||
|
||||
for _, item := range req.Proxies {
|
||||
// Trim all string fields
|
||||
host := strings.TrimSpace(item.Host)
|
||||
protocol := strings.TrimSpace(item.Protocol)
|
||||
username := strings.TrimSpace(item.Username)
|
||||
password := strings.TrimSpace(item.Password)
|
||||
|
||||
// Check for duplicates (same host, port, username, password)
|
||||
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password)
|
||||
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
|
||||
return
|
||||
@@ -278,11 +285,11 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
|
||||
// Create proxy with default name
|
||||
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
Name: "default",
|
||||
Protocol: item.Protocol,
|
||||
Host: item.Host,
|
||||
Protocol: protocol,
|
||||
Host: host,
|
||||
Port: item.Port,
|
||||
Username: item.Username,
|
||||
Password: item.Password,
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
// If creation fails due to duplicate, count as skipped
|
||||
|
||||
@@ -4,15 +4,15 @@ import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
|
||||
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
|
||||
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
|
||||
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SystemHandler handles system-related operations
|
||||
@@ -18,9 +17,9 @@ type SystemHandler struct {
|
||||
}
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
|
||||
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||
return &SystemHandler{
|
||||
updateSvc: service.NewUpdateService(rdb, version, buildType),
|
||||
updateSvc: updateSvc,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
@@ -14,10 +15,10 @@ import (
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
usageService *service.UsageService
|
||||
adminService service.AdminService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
usageService *service.UsageService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
@@ -82,7 +83,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := repository.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -126,6 +128,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
@@ -285,29 +297,8 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id
|
||||
// Models handles listing available models
|
||||
// GET /v1/models
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
models := []gin.H{
|
||||
{
|
||||
"id": "claude-opus-4-5-20251101",
|
||||
"type": "model",
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"created_at": "2025-11-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
"id": "claude-sonnet-4-5-20250929",
|
||||
"type": "model",
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"created_at": "2025-09-29T00:00:00Z",
|
||||
},
|
||||
{
|
||||
"id": "claude-haiku-4-5-20251001",
|
||||
"type": "model",
|
||||
"display_name": "Claude Haiku 4.5",
|
||||
"created_at": "2025-10-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": models,
|
||||
"data": claude.DefaultModels,
|
||||
"object": "list",
|
||||
})
|
||||
}
|
||||
@@ -443,3 +434,155 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// CountTokens handles token counting endpoint
|
||||
// POST /v1/messages/count_tokens
|
||||
// 特点:校验订阅/余额,但不计算并发、不记录使用量
|
||||
func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求获取模型名
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取订阅信息(可能为nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 计算粘性会话 hash
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求(不记录使用量)
|
||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
|
||||
log.Printf("Forward count_tokens request failed: %v", err)
|
||||
// 错误响应已在 ForwardCountTokens 中处理
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
|
||||
func isWarmupRequest(body []byte) bool {
|
||||
// 快速检查:如果body不包含关键字,直接返回false
|
||||
bodyStr := string(body)
|
||||
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 解析完整请求
|
||||
var req struct {
|
||||
Messages []struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"messages"`
|
||||
System []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"system"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, system := range req.System {
|
||||
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
events := []string{
|
||||
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
|
||||
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
|
||||
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
_, _ = c.Writer.WriteString(event + "\n\n")
|
||||
c.Writer.Flush()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": "msg_mock_warmup",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 2,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,10 +2,6 @@ package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/handler/admin"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
@@ -41,30 +37,3 @@ type BuildInfo struct {
|
||||
Version string
|
||||
BuildType string // "source" for manual builds, "release" for CI builds
|
||||
}
|
||||
|
||||
// NewHandlers creates a new Handlers instance with all handlers initialized
|
||||
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: NewAuthHandler(services.Auth),
|
||||
User: NewUserHandler(services.User),
|
||||
APIKey: NewAPIKeyHandler(services.ApiKey),
|
||||
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
|
||||
Redeem: NewRedeemHandler(services.Redeem),
|
||||
Subscription: NewSubscriptionHandler(services.Subscription),
|
||||
Admin: &AdminHandlers{
|
||||
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
|
||||
User: admin.NewUserHandler(services.Admin),
|
||||
Group: admin.NewGroupHandler(services.Admin),
|
||||
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
|
||||
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
|
||||
Proxy: admin.NewProxyHandler(services.Admin),
|
||||
Redeem: admin.NewRedeemHandler(services.Admin),
|
||||
Setting: admin.NewSettingHandler(services.Setting, services.Email),
|
||||
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
|
||||
Subscription: admin.NewSubscriptionHandler(services.Subscription),
|
||||
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
|
||||
},
|
||||
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
|
||||
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
@@ -68,9 +69,9 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
var records []model.UsageLog
|
||||
var result *repository.PaginationResult
|
||||
var result *pagination.PaginationResult
|
||||
var err error
|
||||
|
||||
if apiKeyID > 0 {
|
||||
@@ -362,7 +363,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Verify ownership of all requested API keys
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to verify API key ownership")
|
||||
return
|
||||
|
||||
102
backend/internal/handler/wire.go
Normal file
102
backend/internal/handler/wire.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"sub2api/internal/handler/admin"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// ProvideAdminHandlers creates the AdminHandlers struct
|
||||
func ProvideAdminHandlers(
|
||||
dashboardHandler *admin.DashboardHandler,
|
||||
userHandler *admin.UserHandler,
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
usageHandler *admin.UsageHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
User: userHandler,
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
OAuth: oauthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Setting: settingHandler,
|
||||
System: systemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Usage: usageHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService)
|
||||
}
|
||||
|
||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
|
||||
return NewSettingHandler(settingService, buildInfo.Version)
|
||||
}
|
||||
|
||||
// ProvideHandlers creates the Handlers struct
|
||||
func ProvideHandlers(
|
||||
authHandler *AuthHandler,
|
||||
userHandler *UserHandler,
|
||||
apiKeyHandler *APIKeyHandler,
|
||||
usageHandler *UsageHandler,
|
||||
redeemHandler *RedeemHandler,
|
||||
subscriptionHandler *SubscriptionHandler,
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
User: userHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
Usage: usageHandler,
|
||||
Redeem: redeemHandler,
|
||||
Subscription: subscriptionHandler,
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
Setting: settingHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all handlers
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Top-level handlers
|
||||
NewAuthHandler,
|
||||
NewUserHandler,
|
||||
NewAPIKeyHandler,
|
||||
NewUsageHandler,
|
||||
NewRedeemHandler,
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
admin.NewDashboardHandler,
|
||||
admin.NewUserHandler,
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewSettingHandler,
|
||||
ProvideSystemHandler,
|
||||
admin.NewSubscriptionHandler,
|
||||
admin.NewUsageHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
ProvideHandlers,
|
||||
)
|
||||
38
backend/internal/infrastructure/database.go
Normal file
38
backend/internal/infrastructure/database.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// InitDB 初始化数据库连接
|
||||
func InitDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
// 初始化时区(在数据库连接之前,确保时区设置正确)
|
||||
if err := timezone.Init(cfg.Timezone); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gormConfig := &gorm.Config{}
|
||||
if cfg.Server.Mode == "debug" {
|
||||
gormConfig.Logger = logger.Default.LogMode(logger.Info)
|
||||
}
|
||||
|
||||
// 使用带时区的 DSN 连接数据库
|
||||
db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 自动迁移(始终执行,确保数据库结构与代码同步)
|
||||
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
|
||||
if err := model.AutoMigrate(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
16
backend/internal/infrastructure/redis.go
Normal file
16
backend/internal/infrastructure/redis.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// InitRedis 初始化 Redis 客户端
|
||||
func InitRedis(cfg *config.Config) *redis.Client {
|
||||
return redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
})
|
||||
}
|
||||
25
backend/internal/infrastructure/wire.go
Normal file
25
backend/internal/infrastructure/wire.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ProviderSet 提供基础设施层的依赖
|
||||
var ProviderSet = wire.NewSet(
|
||||
ProvideDB,
|
||||
ProvideRedis,
|
||||
)
|
||||
|
||||
// ProvideDB 提供数据库连接
|
||||
func ProvideDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
return InitDB(cfg)
|
||||
}
|
||||
|
||||
// ProvideRedis 提供 Redis 客户端
|
||||
func ProvideRedis(cfg *config.Config) *redis.Client {
|
||||
return InitRedis(cfg)
|
||||
}
|
||||
@@ -263,3 +263,17 @@ func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
|
||||
// 启用后,标题生成、Warmup等预热请求将返回mock响应,不消耗上游token
|
||||
func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
74
backend/internal/pkg/claude/constants.go
Normal file
74
backend/internal/pkg/claude/constants.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package claude
|
||||
|
||||
// Claude Code 客户端相关常量
|
||||
|
||||
// Beta header 常量
|
||||
const (
|
||||
BetaOAuth = "oauth-2025-04-20"
|
||||
BetaClaudeCode = "claude-code-20250219"
|
||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
)
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// Claude Code 客户端默认请求头
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
"X-Stainless-Package-Version": "0.52.0",
|
||||
"X-Stainless-OS": "Linux",
|
||||
"X-Stainless-Arch": "x64",
|
||||
"X-Stainless-Runtime": "node",
|
||||
"X-Stainless-Runtime-Version": "v22.14.0",
|
||||
"X-Stainless-Retry-Count": "0",
|
||||
"X-Stainless-Timeout": "60",
|
||||
"X-App": "cli",
|
||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||
}
|
||||
|
||||
// Model 表示一个 Claude 模型
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels Claude Code 客户端支持的默认模型列表
|
||||
var DefaultModels = []Model{
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Opus 4.5",
|
||||
CreatedAt: "2025-11-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Sonnet 4.5",
|
||||
CreatedAt: "2025-09-29T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4-5-20251001",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Haiku 4.5",
|
||||
CreatedAt: "2025-10-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
// DefaultModelIDs 返回默认模型的 ID 列表
|
||||
func DefaultModelIDs() []string {
|
||||
ids := make([]string, len(DefaultModels))
|
||||
for i, m := range DefaultModels {
|
||||
ids[i] = m.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// DefaultTestModel 测试时使用的默认模型
|
||||
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||
42
backend/internal/pkg/pagination/pagination.go
Normal file
42
backend/internal/pkg/pagination/pagination.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package pagination
|
||||
|
||||
// PaginationParams 分页参数
|
||||
type PaginationParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// DefaultPagination 默认分页参数
|
||||
func DefaultPagination() PaginationParams {
|
||||
return PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset 计算偏移量
|
||||
func (p PaginationParams) Offset() int {
|
||||
if p.Page < 1 {
|
||||
p.Page = 1
|
||||
}
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
// Limit 获取限制数
|
||||
func (p PaginationParams) Limit() int {
|
||||
if p.PageSize < 1 {
|
||||
return 20
|
||||
}
|
||||
if p.PageSize > 100 {
|
||||
return 100
|
||||
}
|
||||
return p.PageSize
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in
|
||||
})
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果(与repository.PaginationResult兼容)
|
||||
// PaginationResult 分页结果(与pagination.PaginationResult兼容)
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
|
||||
8
backend/internal/pkg/usagestats/account_stats.go
Normal file
8
backend/internal/pkg/usagestats/account_stats.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package usagestats
|
||||
|
||||
// AccountStats 账号使用统计
|
||||
type AccountStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -47,12 +48,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
|
||||
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
|
||||
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
|
||||
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
var accounts []model.Account
|
||||
var total int64
|
||||
|
||||
@@ -94,7 +95,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
|
||||
pages++
|
||||
}
|
||||
|
||||
return accounts, &PaginationResult{
|
||||
return accounts, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -226,7 +227,7 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"rate_limited_at": now,
|
||||
"rate_limited_at": now,
|
||||
"rate_limit_reset_at": resetAt,
|
||||
}).Error
|
||||
}
|
||||
|
||||
51
backend/internal/repository/api_key_cache.go
Normal file
51
backend/internal/repository/api_key_cache.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return c.rdb.Incr(ctx, apiKey).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
|
||||
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
@@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
pages++
|
||||
}
|
||||
|
||||
return keys, &PaginationResult{
|
||||
return keys, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
|
||||
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
@@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
pages++
|
||||
}
|
||||
|
||||
return keys, &PaginationResult{
|
||||
return keys, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
174
backend/internal/repository/billing_cache.go
Normal file
174
backend/internal/repository/billing_cache.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
var (
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type billingCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
|
||||
return &billingCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
|
||||
result := &ports.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
235
backend/internal/repository/claude_oauth_service.go
Normal file
235
backend/internal/repository/claude_oauth_service.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type claudeOAuthService struct{}
|
||||
|
||||
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
||||
return &claudeOAuthService{}
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
targetURL := "https://claude.ai/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetSuccessResult(&orgs).
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if len(orgs) == 0 {
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||
SetHeader("Cache-Control", "no-cache").
|
||||
SetHeader("Origin", "https://claude.ai").
|
||||
SetHeader("Referer", "https://claude.ai/new").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&result).
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if result.RedirectURI == "" {
|
||||
return "", fmt.Errorf("no redirect_uri in response")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(result.RedirectURI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||
}
|
||||
|
||||
queryParams := parsedURL.Query()
|
||||
authCode := queryParams.Get("code")
|
||||
responseState := queryParams.Get("state")
|
||||
|
||||
if authCode == "" {
|
||||
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||
}
|
||||
|
||||
fullCode := authCode
|
||||
if responseState != "" {
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if len(code) > 0 {
|
||||
parts := make([]string, 0, 2)
|
||||
for i, part := range []rune(code) {
|
||||
if part == '#' {
|
||||
authCode = code[:i]
|
||||
codeState = code[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
authCode = code
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"code_verifier": codeVerifier,
|
||||
}
|
||||
|
||||
if codeState != "" {
|
||||
reqBody["state"] = codeState
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauth.ClientID)
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
ImpersonateChrome().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
64
backend/internal/repository/claude_service.go
Normal file
64
backend/internal/repository/claude_service.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
type claudeUpstreamService struct {
|
||||
defaultClient *http.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &claudeUpstreamService{
|
||||
defaultClient: &http.Client{Transport: transport},
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
if proxyURL == "" {
|
||||
return s.defaultClient.Do(req)
|
||||
}
|
||||
client := s.createProxyClient(proxyURL)
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return s.defaultClient
|
||||
}
|
||||
|
||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
59
backend/internal/repository/claude_usage_service.go
Normal file
59
backend/internal/repository/claude_usage_service.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
type claudeUsageService struct{}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{}
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
return &usageResp, nil
|
||||
}
|
||||
132
backend/internal/repository/concurrency_cache.go
Normal file
132
backend/internal/repository/concurrency_cache.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
accountConcurrencyKeyPrefix = "concurrency:account:"
|
||||
userConcurrencyKeyPrefix = "concurrency:user:"
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
concurrencyTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
|
||||
return &concurrencyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
48
backend/internal/repository/email_cache.go
Normal file
48
backend/internal/repository/email_cache.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
|
||||
return &emailCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data ports.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
35
backend/internal/repository/gateway_cache.go
Normal file
35
backend/internal/repository/gateway_cache.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Get(ctx, key).Int64()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
116
backend/internal/repository/github_release_service.go
Normal file
116
backend/internal/repository/github_release_service.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release service.GitHubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &release, nil
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// SECURITY: Check Content-Length if available
|
||||
if resp.ContentLength > maxSize {
|
||||
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
|
||||
}
|
||||
|
||||
out, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxSize)
|
||||
if written > maxSize {
|
||||
os.Remove(dest) // Clean up partial file
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) {
|
||||
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
}
|
||||
|
||||
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
||||
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) {
|
||||
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
var groups []model.Group
|
||||
var total int64
|
||||
|
||||
@@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination
|
||||
pages++
|
||||
}
|
||||
|
||||
return groups, &PaginationResult{
|
||||
return groups, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
47
backend/internal/repository/identity_cache.go
Normal file
47
backend/internal/repository/identity_cache.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
fingerprintKeyPrefix = "fingerprint:"
|
||||
fingerprintTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
|
||||
return &identityCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fp ports.Fingerprint
|
||||
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fp, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := json.Marshal(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||
}
|
||||
73
backend/internal/repository/pricing_service.go
Normal file
73
backend/internal/repository/pricing_service.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
type pricingRemoteClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewPricingRemoteClient() service.PricingRemoteClient {
|
||||
return &pricingRemoteClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 哈希文件格式:hash filename 或者纯 hash
|
||||
hash := strings.TrimSpace(string(body))
|
||||
parts := strings.Fields(hash)
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
104
backend/internal/repository/proxy_probe_service.go
Normal file
104
backend/internal/repository/proxy_probe_service.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type proxyProbeService struct{}
|
||||
|
||||
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||
return &proxyProbeService{}
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
transport, err := createProxyTransport(proxyURL)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var ipInfo struct {
|
||||
IP string `json:"ip"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &service.ProxyExitInfo{
|
||||
IP: ipInfo.IP,
|
||||
City: ipInfo.City,
|
||||
Region: ipInfo.Region,
|
||||
Country: ipInfo.Country,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
case "socks5":
|
||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) {
|
||||
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) {
|
||||
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
var proxies []model.Proxy
|
||||
var total int64
|
||||
|
||||
@@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination
|
||||
pages++
|
||||
}
|
||||
|
||||
return proxies, &PaginationResult{
|
||||
return proxies, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
49
backend/internal/repository/redeem_cache.go
Normal file
49
backend/internal/repository/redeem_cache.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
|
||||
redeemLockKeyPrefix = "redeem:lock:"
|
||||
redeemRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type redeemCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
|
||||
return &redeemCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
|
||||
key := redeemLockKeyPrefix + code
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
|
||||
key := redeemLockKeyPrefix + code
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) {
|
||||
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
|
||||
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) {
|
||||
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
var codes []model.RedeemCode
|
||||
var total int64
|
||||
|
||||
@@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin
|
||||
pages++
|
||||
}
|
||||
|
||||
return codes, &PaginationResult{
|
||||
return codes, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Repositories 所有仓库的集合
|
||||
type Repositories struct {
|
||||
User *UserRepository
|
||||
@@ -16,59 +12,3 @@ type Repositories struct {
|
||||
Setting *SettingRepository
|
||||
UserSubscription *UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// NewRepositories 创建所有仓库
|
||||
func NewRepositories(db *gorm.DB) *Repositories {
|
||||
return &Repositories{
|
||||
User: NewUserRepository(db),
|
||||
ApiKey: NewApiKeyRepository(db),
|
||||
Group: NewGroupRepository(db),
|
||||
Account: NewAccountRepository(db),
|
||||
Proxy: NewProxyRepository(db),
|
||||
RedeemCode: NewRedeemCodeRepository(db),
|
||||
UsageLog: NewUsageLogRepository(db),
|
||||
Setting: NewSettingRepository(db),
|
||||
UserSubscription: NewUserSubscriptionRepository(db),
|
||||
}
|
||||
}
|
||||
|
||||
// PaginationParams 分页参数
|
||||
type PaginationParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// DefaultPagination 默认分页参数
|
||||
func DefaultPagination() PaginationParams {
|
||||
return PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset 计算偏移量
|
||||
func (p PaginationParams) Offset() int {
|
||||
if p.Page < 1 {
|
||||
p.Page = 1
|
||||
}
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
// Limit 获取限制数
|
||||
func (p PaginationParams) Limit() int {
|
||||
if p.PageSize < 1 {
|
||||
return 20
|
||||
}
|
||||
if p.PageSize > 100 {
|
||||
return 100
|
||||
}
|
||||
return p.PageSize
|
||||
}
|
||||
|
||||
55
backend/internal/repository/turnstile_service.go
Normal file
55
backend/internal/repository/turnstile_service.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
type turnstileVerifier struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||
return &turnstileVerifier{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
|
||||
formData := url.Values{}
|
||||
formData.Set("secret", secretKey)
|
||||
formData.Set("response", token)
|
||||
if remoteIP != "" {
|
||||
formData.Set("remoteip", remoteIP)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := v.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result service.TurnstileVerifyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
28
backend/internal/repository/update_cache.go
Normal file
28
backend/internal/repository/update_cache.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const updateCacheKey = "update:latest"
|
||||
|
||||
type updateCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
|
||||
return &updateCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
|
||||
return c.rdb.Get(ctx, updateCacheKey).Result()
|
||||
}
|
||||
|
||||
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
|
||||
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
|
||||
}
|
||||
@@ -3,7 +3,9 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/pkg/usagestats"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -30,7 +32,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag
|
||||
return &log, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -49,7 +51,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &PaginationResult{
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -57,7 +59,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -76,7 +78,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &PaginationResult{
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -270,7 +272,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -289,7 +291,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &PaginationResult{
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -297,7 +299,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
|
||||
@@ -306,7 +308,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
|
||||
@@ -315,7 +317,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
@@ -324,7 +326,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
|
||||
@@ -337,15 +339,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
|
||||
}
|
||||
|
||||
// AccountStats 账号使用统计
|
||||
type AccountStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
|
||||
// GetAccountTodayStats 获取账号今日统计
|
||||
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*AccountStats, error) {
|
||||
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||
today := timezone.Today()
|
||||
|
||||
var stats struct {
|
||||
@@ -367,7 +362,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountStats{
|
||||
return &usagestats.AccountStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
@@ -375,7 +370,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// GetAccountWindowStats 获取账号时间窗口内的统计
|
||||
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*AccountStats, error) {
|
||||
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
var stats struct {
|
||||
Requests int64 `gorm:"column:requests"`
|
||||
Tokens int64 `gorm:"column:tokens"`
|
||||
@@ -395,7 +390,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountStats{
|
||||
return &usagestats.AccountStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
@@ -436,75 +431,13 @@ type UserUsageTrendPoint struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GetUsageTrend returns usage trend data grouped by date
|
||||
// granularity: "day" or "hour"
|
||||
func (r *UsageLogRepository) GetUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
|
||||
var results []TrendDataPoint
|
||||
|
||||
// Choose date format based on granularity
|
||||
var dateFormat string
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
} else {
|
||||
dateFormat = "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, ?) as date,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`, dateFormat).
|
||||
Where("created_at >= ? AND created_at < ?", startTime, endTime).
|
||||
Group("date").
|
||||
Order("date ASC").
|
||||
Scan(&results).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetModelStats returns usage statistics grouped by model
|
||||
func (r *UsageLogRepository) GetModelStats(ctx context.Context, startTime, endTime time.Time) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`).
|
||||
Where("created_at >= ? AND created_at < ?", startTime, endTime).
|
||||
Group("model").
|
||||
Order("total_tokens DESC").
|
||||
Scan(&results).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
@@ -780,7 +713,7 @@ type UsageLogFilters struct {
|
||||
}
|
||||
|
||||
// ListWithFilters lists usage logs with optional filters (for admin)
|
||||
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -816,7 +749,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &PaginationResult{
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -838,7 +771,7 @@ type UsageStats struct {
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
@@ -964,6 +897,76 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
||||
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
|
||||
var results []TrendDataPoint
|
||||
|
||||
var dateFormat string
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
} else {
|
||||
dateFormat = "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, ?) as date,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`, dateFormat).
|
||||
Where("created_at >= ? AND created_at < ?", startTime, endTime)
|
||||
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
db = db.Where("api_key_id = ?", apiKeyID)
|
||||
}
|
||||
|
||||
err := db.Group("date").Order("date ASC").Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`).
|
||||
Where("created_at >= ? AND created_at < ?", startTime, endTime)
|
||||
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
db = db.Where("api_key_id = ?", apiKeyID)
|
||||
}
|
||||
|
||||
err := db.Group("model").Order("total_tokens DESC").Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetGlobalStats gets usage statistics for all users within a time range
|
||||
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
var stats struct {
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) {
|
||||
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists users with optional filtering by status, role, and search query
|
||||
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) {
|
||||
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
@@ -81,7 +82,7 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP
|
||||
pages++
|
||||
}
|
||||
|
||||
return users, &PaginationResult{
|
||||
return users, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -127,4 +128,3 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
|
||||
}
|
||||
|
||||
// ListByGroupID 获取分组的所有订阅(分页)
|
||||
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) {
|
||||
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
var total int64
|
||||
|
||||
@@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
pages++
|
||||
}
|
||||
|
||||
return subs, &PaginationResult{
|
||||
return subs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) {
|
||||
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
var total int64
|
||||
|
||||
@@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
|
||||
pages++
|
||||
}
|
||||
|
||||
return subs, &PaginationResult{
|
||||
return subs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
51
backend/internal/repository/wire.go
Normal file
51
backend/internal/repository/wire.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
NewApiKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewSettingRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
wire.Struct(new(Repositories), "*"),
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewApiKeyCache,
|
||||
NewConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
NewRedeemCache,
|
||||
NewUpdateCache,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
NewPricingRemoteClient,
|
||||
NewGitHubReleaseClient,
|
||||
NewProxyExitInfoProber,
|
||||
NewClaudeUsageFetcher,
|
||||
NewClaudeOAuthClient,
|
||||
NewClaudeUpstream,
|
||||
|
||||
// Bind concrete repositories to service port interfaces
|
||||
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
||||
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
|
||||
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
|
||||
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
|
||||
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
|
||||
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
|
||||
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
|
||||
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
|
||||
)
|
||||
45
backend/internal/server/http.go
Normal file
45
backend/internal/server/http.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// ProviderSet 提供服务器层的依赖
|
||||
var ProviderSet = wire.NewSet(
|
||||
ProvideRouter,
|
||||
ProvideHTTPServer,
|
||||
)
|
||||
|
||||
// ProvideRouter 提供路由器
|
||||
func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
|
||||
return SetupRouter(r, cfg, handlers, services, repos)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
||||
return &http.Server{
|
||||
Addr: cfg.Server.Address(),
|
||||
Handler: router,
|
||||
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
|
||||
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
|
||||
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
|
||||
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
|
||||
// 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟
|
||||
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
||||
}
|
||||
}
|
||||
289
backend/internal/server/router.go
Normal file
289
backend/internal/server/router.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/handler"
|
||||
"sub2api/internal/middleware"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
"sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SetupRouter 配置路由器中间件和路由
|
||||
func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware.Logger())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, services, repos)
|
||||
|
||||
// Serve embedded frontend if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// registerRoutes 注册所有 HTTP 路由
|
||||
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Claude Code 遥测日志(忽略,直接返回200)
|
||||
r.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// Setup status endpoint (always returns needs_setup: false in normal mode)
|
||||
// This is used by the frontend to detect when the service has restarted after setup
|
||||
r.GET("/setup/status", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": gin.H{
|
||||
"needs_setup": false,
|
||||
"step": "completed",
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// API v1
|
||||
v1 := r.Group("/api/v1")
|
||||
{
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
settings := v1.Group("/settings")
|
||||
{
|
||||
settings.GET("/public", h.Setting.GetPublicSettings)
|
||||
}
|
||||
|
||||
// 需要认证的接口
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
|
||||
{
|
||||
// 当前用户信息
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
{
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
keys := authenticated.Group("/keys")
|
||||
{
|
||||
keys.GET("", h.APIKey.List)
|
||||
keys.GET("/:id", h.APIKey.GetByID)
|
||||
keys.POST("", h.APIKey.Create)
|
||||
keys.PUT("/:id", h.APIKey.Update)
|
||||
keys.DELETE("/:id", h.APIKey.Delete)
|
||||
}
|
||||
|
||||
// 用户可用分组(非管理员接口)
|
||||
groups := authenticated.Group("/groups")
|
||||
{
|
||||
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
||||
}
|
||||
|
||||
// 使用记录
|
||||
usage := authenticated.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Usage.List)
|
||||
usage.GET("/:id", h.Usage.GetByID)
|
||||
usage.GET("/stats", h.Usage.Stats)
|
||||
// User dashboard endpoints
|
||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
redeem := authenticated.Group("/redeem")
|
||||
{
|
||||
redeem.POST("", h.Redeem.Redeem)
|
||||
redeem.GET("/history", h.Redeem.GetHistory)
|
||||
}
|
||||
|
||||
// 用户订阅
|
||||
subscriptions := authenticated.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Subscription.List)
|
||||
subscriptions.GET("/active", h.Subscription.GetActive)
|
||||
subscriptions.GET("/progress", h.Subscription.GetProgress)
|
||||
subscriptions.GET("/summary", h.Subscription.GetSummary)
|
||||
}
|
||||
}
|
||||
|
||||
// 管理员接口
|
||||
admin := v1.Group("/admin")
|
||||
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
|
||||
{
|
||||
// 仪表盘
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
||||
}
|
||||
|
||||
// 用户管理
|
||||
users := admin.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.User.List)
|
||||
users.GET("/:id", h.Admin.User.GetByID)
|
||||
users.POST("", h.Admin.User.Create)
|
||||
users.PUT("/:id", h.Admin.User.Update)
|
||||
users.DELETE("/:id", h.Admin.User.Delete)
|
||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||
}
|
||||
|
||||
// 分组管理
|
||||
groups := admin.Group("/groups")
|
||||
{
|
||||
groups.GET("", h.Admin.Group.List)
|
||||
groups.GET("/all", h.Admin.Group.GetAll)
|
||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||
groups.POST("", h.Admin.Group.Create)
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
|
||||
// 账号管理
|
||||
accounts := admin.Group("/accounts")
|
||||
{
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
|
||||
// OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
||||
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
||||
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
|
||||
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
|
||||
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
||||
}
|
||||
|
||||
// 代理管理
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
proxies.GET("", h.Admin.Proxy.List)
|
||||
proxies.GET("/all", h.Admin.Proxy.GetAll)
|
||||
proxies.GET("/:id", h.Admin.Proxy.GetByID)
|
||||
proxies.POST("", h.Admin.Proxy.Create)
|
||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
||||
}
|
||||
|
||||
// 卡密管理
|
||||
codes := admin.Group("/redeem-codes")
|
||||
{
|
||||
codes.GET("", h.Admin.Redeem.List)
|
||||
codes.GET("/stats", h.Admin.Redeem.GetStats)
|
||||
codes.GET("/export", h.Admin.Redeem.Export)
|
||||
codes.GET("/:id", h.Admin.Redeem.GetByID)
|
||||
codes.POST("/generate", h.Admin.Redeem.Generate)
|
||||
codes.DELETE("/:id", h.Admin.Redeem.Delete)
|
||||
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
|
||||
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
|
||||
}
|
||||
|
||||
// 系统设置
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
}
|
||||
|
||||
// 系统管理
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
system.GET("/version", h.Admin.System.GetVersion)
|
||||
system.GET("/check-updates", h.Admin.System.CheckUpdates)
|
||||
system.POST("/update", h.Admin.System.PerformUpdate)
|
||||
system.POST("/rollback", h.Admin.System.Rollback)
|
||||
system.POST("/restart", h.Admin.System.RestartService)
|
||||
}
|
||||
|
||||
// 订阅管理
|
||||
subscriptions := admin.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Admin.Subscription.List)
|
||||
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
|
||||
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
|
||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||
}
|
||||
|
||||
// 分组下的订阅列表
|
||||
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
|
||||
|
||||
// 用户下的订阅列表
|
||||
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
|
||||
|
||||
// 使用记录管理
|
||||
usage := admin.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Admin.Usage.List)
|
||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -41,12 +42,12 @@ type UpdateAccountRequest struct {
|
||||
|
||||
// AccountService 账号管理服务
|
||||
type AccountService struct {
|
||||
accountRepo *repository.AccountRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
accountRepo ports.AccountRepository
|
||||
groupRepo ports.GroupRepository
|
||||
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
|
||||
func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
|
||||
}
|
||||
|
||||
// List 获取账号列表
|
||||
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
|
||||
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
||||
|
||||
@@ -10,12 +10,12 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testModel = "claude-sonnet-4-5-20250929"
|
||||
)
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
@@ -37,19 +36,17 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
repos *repository.Repositories
|
||||
oauthService *OAuthService
|
||||
httpClient *http.Client
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
claudeUpstream ClaudeUpstream
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
repos: repos,
|
||||
oauthService: oauthService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
claudeUpstream: claudeUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,10 +59,10 @@ func generateSessionString() string {
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
|
||||
}
|
||||
|
||||
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts
|
||||
func createTestPayload() map[string]interface{} {
|
||||
// createTestPayload creates a Claude Code style test request payload
|
||||
func createTestPayload(modelID string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"model": testModel,
|
||||
"model": modelID,
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "user",
|
||||
@@ -92,45 +89,48 @@ func createTestPayload() map[string]interface{} {
|
||||
"metadata": map[string]string{
|
||||
"user_id": generateSessionString(),
|
||||
},
|
||||
"max_tokens": 1024,
|
||||
"max_tokens": 1024,
|
||||
"temperature": 1,
|
||||
"stream": true,
|
||||
}
|
||||
}
|
||||
|
||||
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
|
||||
func createApiKeyTestPayload(model string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi",
|
||||
},
|
||||
},
|
||||
"max_tokens": 1024,
|
||||
"stream": true,
|
||||
"stream": true,
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error {
|
||||
// All account types use full Claude Code client characteristics, only auth header differs
|
||||
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get account
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Account not found")
|
||||
}
|
||||
|
||||
// Determine authentication method based on account type
|
||||
// Determine the model to use
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = claude.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping != nil && len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
var authToken string
|
||||
var authType string // "bearer" for OAuth, "apikey" for API Key
|
||||
var useBearer bool
|
||||
var apiURL string
|
||||
|
||||
if account.IsOAuth() {
|
||||
// OAuth or Setup Token account
|
||||
authType = "bearer"
|
||||
// OAuth or Setup Token - use Bearer token
|
||||
useBearer = true
|
||||
apiURL = testClaudeAPIURL
|
||||
authToken = account.GetCredential("access_token")
|
||||
if authToken == "" {
|
||||
@@ -141,7 +141,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
needRefresh := false
|
||||
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer
|
||||
if err == nil && time.Now().Unix()+300 > expiresAt {
|
||||
needRefresh = true
|
||||
}
|
||||
}
|
||||
@@ -154,19 +154,17 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
authToken = tokenInfo.AccessToken
|
||||
}
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key account
|
||||
authType = "apikey"
|
||||
// API Key - use x-api-key header
|
||||
useBearer = false
|
||||
authToken = account.GetCredential("api_key")
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
// Get base URL (use default if not set)
|
||||
apiURL = account.GetBaseURL()
|
||||
if apiURL == "" {
|
||||
apiURL = "https://api.anthropic.com"
|
||||
}
|
||||
// Append /v1/messages endpoint
|
||||
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
@@ -179,57 +177,42 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create test request payload
|
||||
var payload map[string]interface{}
|
||||
var actualModel string
|
||||
if authType == "apikey" {
|
||||
// Use simpler payload for API Key (without Claude Code specific fields)
|
||||
// Apply model mapping if configured
|
||||
actualModel = account.GetMappedModel(testModel)
|
||||
payload = createApiKeyTestPayload(actualModel)
|
||||
} else {
|
||||
actualModel = testModel
|
||||
payload = createTestPayload()
|
||||
}
|
||||
// Create Claude Code style payload (same for all account types)
|
||||
payload := createTestPayload(testModelID)
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event with model info
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel})
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
// Set headers based on auth type
|
||||
// Set common headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
|
||||
|
||||
if authType == "bearer" {
|
||||
// Apply Claude Code client headers
|
||||
for key, value := range claude.DefaultHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Set authentication header
|
||||
if useBearer {
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
|
||||
} else {
|
||||
// API Key uses x-api-key header
|
||||
req.Header.Set("x-api-key", authToken)
|
||||
}
|
||||
|
||||
// Configure proxy if account has one
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
resp, err := s.claudeUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -252,7 +235,6 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// Stream ended, send complete event
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
@@ -310,5 +292,5 @@ func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
|
||||
log.Printf("Account test error: %s", errorMsg)
|
||||
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
||||
return fmt.Errorf(errorMsg)
|
||||
return fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
@@ -2,17 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
@@ -35,10 +31,10 @@ type WindowStats struct {
|
||||
|
||||
// UsageProgress 使用量进度
|
||||
type UsageProgress struct {
|
||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
@@ -65,21 +61,26 @@ type ClaudeUsageResponse struct {
|
||||
} `json:"seven_day_sonnet"`
|
||||
}
|
||||
|
||||
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||
type ClaudeUsageFetcher interface {
|
||||
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||
}
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
type AccountUsageService struct {
|
||||
repos *repository.Repositories
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
oauthService *OAuthService
|
||||
httpClient *http.Client
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
|
||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
repos: repos,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
oauthService: oauthService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
usageFetcher: usageFetcher,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +89,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS
|
||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||
// API Key账号: 不支持usage查询
|
||||
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
@@ -148,7 +149,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
|
||||
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
||||
return
|
||||
@@ -163,7 +164,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
|
||||
|
||||
// GetTodayStats 获取账号今日统计
|
||||
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
|
||||
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
|
||||
stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get today stats failed: %w", err)
|
||||
}
|
||||
@@ -177,58 +178,23 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
||||
|
||||
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
||||
// 获取access token(从credentials中获取)
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
|
||||
// 获取代理配置
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||
usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var usageResp ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
// 转换为UsageInfo
|
||||
now := time.Now()
|
||||
return s.buildUsageInfo(&usageResp, &now), nil
|
||||
return s.buildUsageInfo(usageResp, &now), nil
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
|
||||
@@ -2,20 +2,14 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -177,44 +171,63 @@ type ProxyTestResult struct {
|
||||
Country string `json:"country,omitempty"`
|
||||
}
|
||||
|
||||
// ProxyExitInfo represents proxy exit information from ipinfo.io
|
||||
type ProxyExitInfo struct {
|
||||
IP string
|
||||
City string
|
||||
Region string
|
||||
Country string
|
||||
}
|
||||
|
||||
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
|
||||
type ProxyExitInfoProber interface {
|
||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||
}
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
accountRepo *repository.AccountRepository
|
||||
proxyRepo *repository.ProxyRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
redeemCodeRepo *repository.RedeemCodeRepository
|
||||
usageLogRepo *repository.UsageLogRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
accountRepo ports.AccountRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
redeemCodeRepo ports.RedeemCodeRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(repos *repository.Repositories) AdminService {
|
||||
func NewAdminService(
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
accountRepo ports.AccountRepository,
|
||||
proxyRepo ports.ProxyRepository,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
redeemCodeRepo ports.RedeemCodeRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: repos.User,
|
||||
groupRepo: repos.Group,
|
||||
accountRepo: repos.Account,
|
||||
proxyRepo: repos.Proxy,
|
||||
apiKeyRepo: repos.ApiKey,
|
||||
redeemCodeRepo: repos.RedeemCode,
|
||||
usageLogRepo: repos.UsageLog,
|
||||
userSubRepo: repos.UserSubscription,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
|
||||
// 注意:AdminService是接口,需要类型断言
|
||||
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
|
||||
if impl, ok := adminService.(*adminServiceImpl); ok {
|
||||
impl.billingCacheService = billingCacheService
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
}
|
||||
}
|
||||
|
||||
// User management implementations
|
||||
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -383,7 +396,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -404,7 +417,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -575,7 +588,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -585,7 +598,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -703,7 +716,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
|
||||
|
||||
// Proxy management implementations
|
||||
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -788,7 +801,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
|
||||
|
||||
// Redeem code management implementations
|
||||
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -872,79 +885,12 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return testProxyConnection(ctx, proxy)
|
||||
}
|
||||
|
||||
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
|
||||
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
|
||||
proxyURL := proxy.URL()
|
||||
|
||||
// Create HTTP client with proxy
|
||||
transport, err := createProxyTransport(proxyURL)
|
||||
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Failed to create proxy transport: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
// Measure latency
|
||||
startTime := time.Now()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Failed to create request: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Proxy connection failed: %v", err),
|
||||
}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
|
||||
LatencyMs: latencyMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Parse ipinfo.io response
|
||||
var ipInfo struct {
|
||||
IP string `json:"ip"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible but failed to read response",
|
||||
LatencyMs: latencyMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible but failed to parse response",
|
||||
LatencyMs: latencyMs,
|
||||
Message: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -952,38 +898,9 @@ func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestRes
|
||||
Success: true,
|
||||
Message: "Proxy is accessible",
|
||||
LatencyMs: latencyMs,
|
||||
IPAddress: ipInfo.IP,
|
||||
City: ipInfo.City,
|
||||
Region: ipInfo.Region,
|
||||
Country: ipInfo.Country,
|
||||
IPAddress: exitInfo.IP,
|
||||
City: exitInfo.City,
|
||||
Region: exitInfo.Region,
|
||||
Country: exitInfo.Country,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createProxyTransport creates an HTTP transport with the given proxy URL
|
||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
case "socks5":
|
||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
@@ -8,8 +8,9 @@ import (
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,18 +18,16 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
apiKeyRateLimitDuration = time.Hour
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
@@ -47,21 +46,21 @@ type UpdateApiKeyRequest struct {
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.ApiKeyCache
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo *repository.ApiKeyRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
groupRepo *repository.GroupRepository,
|
||||
userSubRepo *repository.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.ApiKeyCache,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
@@ -69,7 +68,7 @@ func NewApiKeyService(
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -112,13 +111,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
@@ -133,16 +130,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
||||
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
|
||||
}
|
||||
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
@@ -237,7 +229,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
@@ -272,7 +264,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
|
||||
}
|
||||
|
||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
||||
if s.rdb != nil {
|
||||
if s.cache != nil {
|
||||
// 这里可以序列化并缓存API Key
|
||||
_ = cacheKey // 使用变量避免未使用错误
|
||||
}
|
||||
@@ -325,9 +317,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
if req.Status != nil {
|
||||
apiKey.Status = *req.Status
|
||||
// 如果状态改变,清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,9 +345,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
}
|
||||
|
||||
// 清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
@@ -399,13 +389,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.rdb != nil {
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
|
||||
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
|
||||
return fmt.Errorf("increment usage: %w", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
|
||||
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"log"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -16,13 +16,13 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid email or password")
|
||||
ErrUserNotActive = errors.New("user is not active")
|
||||
ErrEmailExists = errors.New("email already exists")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
ErrEmailVerifyRequired = errors.New("email verification is required")
|
||||
ErrRegDisabled = errors.New("registration is currently disabled")
|
||||
ErrInvalidCredentials = errors.New("invalid email or password")
|
||||
ErrUserNotActive = errors.New("user is not active")
|
||||
ErrEmailExists = errors.New("email already exists")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
ErrEmailVerifyRequired = errors.New("email verification is required")
|
||||
ErrRegDisabled = errors.New("registration is currently disabled")
|
||||
)
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
@@ -35,7 +35,7 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo *repository.UserRepository
|
||||
userRepo ports.UserRepository
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
@@ -44,33 +44,24 @@ type AuthService struct {
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService {
|
||||
func NewAuthService(
|
||||
userRepo ports.UserRepository,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
}
|
||||
}
|
||||
|
||||
// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证)
|
||||
func (s *AuthService) SetSettingService(settingService *SettingService) {
|
||||
s.settingService = settingService
|
||||
}
|
||||
|
||||
// SetEmailService 设置邮件服务(用于邮件验证)
|
||||
func (s *AuthService) SetEmailService(emailService *EmailService) {
|
||||
s.emailService = emailService
|
||||
}
|
||||
|
||||
// SetTurnstileService 设置Turnstile服务(用于验证码校验)
|
||||
func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) {
|
||||
s.turnstileService = turnstileService
|
||||
}
|
||||
|
||||
// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件)
|
||||
func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) {
|
||||
s.emailQueueService = emailQueueService
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "")
|
||||
|
||||
@@ -5,30 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 缓存Key前缀和TTL
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// 订阅缓存Hash字段
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
@@ -38,35 +18,6 @@ var (
|
||||
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
|
||||
)
|
||||
|
||||
// 预编译的Lua脚本
|
||||
var (
|
||||
// deductBalanceScript: 扣减余额缓存,key不存在则忽略
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateSubUsageScript: 更新订阅用量缓存,key不存在则忽略
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
type subscriptionCacheData struct {
|
||||
Status string
|
||||
@@ -80,15 +31,15 @@ type subscriptionCacheData struct {
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
rdb *redis.Client
|
||||
userRepo *repository.UserRepository
|
||||
subRepo *repository.UserSubscriptionRepository
|
||||
cache ports.BillingCache
|
||||
userRepo ports.UserRepository
|
||||
subRepo ports.UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService {
|
||||
func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
}
|
||||
@@ -100,24 +51,19 @@ func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserReposito
|
||||
|
||||
// GetUserBalance 获取用户余额(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
// Redis不可用,直接查询数据库
|
||||
return s.getUserBalanceFromDB(ctx, userID)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
// 尝试从缓存读取
|
||||
val, err := s.rdb.Get(ctx, key).Result()
|
||||
balance, err := s.cache.GetUserBalance(ctx, userID)
|
||||
if err == nil {
|
||||
balance, parseErr := strconv.ParseFloat(val, 64)
|
||||
if parseErr == nil {
|
||||
return balance, nil
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// 缓存未命中或解析错误,从数据库读取
|
||||
balance, err := s.getUserBalanceFromDB(ctx, userID)
|
||||
// 缓存未命中,从数据库读取
|
||||
balance, err = s.getUserBalanceFromDB(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -143,39 +89,28 @@ func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID i
|
||||
|
||||
// setBalanceCache 设置余额缓存
|
||||
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
|
||||
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
|
||||
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
// 使用预编译的Lua脚本原子性扣减,如果key不存在则忽略
|
||||
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
||||
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
|
||||
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
@@ -188,19 +123,14 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
|
||||
|
||||
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
// 尝试从缓存读取
|
||||
result, err := s.rdb.HGetAll(ctx, key).Result()
|
||||
if err == nil && len(result) > 0 {
|
||||
data, parseErr := s.parseSubscriptionCache(result)
|
||||
if parseErr == nil {
|
||||
return data, nil
|
||||
}
|
||||
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
if err == nil && cacheData != nil {
|
||||
return s.convertFromPortsData(cacheData), nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
@@ -219,6 +149,28 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData {
|
||||
return &subscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData {
|
||||
return &ports.SubscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
// getSubscriptionFromDB 从数据库获取订阅数据
|
||||
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
@@ -236,90 +188,30 @@ func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseSubscriptionCache 解析订阅缓存数据
|
||||
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
|
||||
result := &subscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// setSubscriptionCache 设置订阅缓存
|
||||
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
||||
if s.rdb == nil || data == nil {
|
||||
if s.cache == nil || data == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
|
||||
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
// 使用预编译的Lua脚本原子性增加用量,如果key不存在则忽略
|
||||
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
||||
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
|
||||
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,22 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefixes
|
||||
accountConcurrencyKey = "concurrency:account:"
|
||||
userConcurrencyKey = "concurrency:user:"
|
||||
userWaitCountKey = "concurrency:wait:"
|
||||
|
||||
// TTL for concurrency keys (auto-release safety net)
|
||||
concurrencyKeyTTL = 10 * time.Minute
|
||||
|
||||
// Wait polling interval
|
||||
waitPollInterval = 100 * time.Millisecond
|
||||
|
||||
@@ -28,70 +19,14 @@ const (
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
|
||||
// Pre-compiled Lua scripts for better performance
|
||||
var (
|
||||
// acquireScript: increment counter if below max, return 1 if successful
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
// releaseScript: decrement counter, but don't go below 0
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementWaitScript: increment wait counter if below max, return 1 if successful
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript: decrement wait counter, but don't go below 0
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
type ConcurrencyService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.ConcurrencyCache
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
|
||||
return &ConcurrencyService{rdb: rdb}
|
||||
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
@@ -104,20 +39,6 @@ type AcquireResult struct {
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// acquireSlot is the core implementation for acquiring a concurrency slot
|
||||
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
@@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Try to acquire immediately
|
||||
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
|
||||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: s.makeReleaseFunc(key),
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Not acquired, return with Acquired=false
|
||||
// The caller (gateway handler) will handle waiting with ping support
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tryAcquire attempts to increment the counter if below max
|
||||
// Uses pre-compiled Lua script for atomicity and performance
|
||||
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
|
||||
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("acquire slot failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// makeReleaseFunc creates a function to release a concurrency slot
|
||||
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
|
||||
return func() {
|
||||
// Use background context to ensure release even if original context is cancelled
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
|
||||
// Log error but don't panic - TTL will eventually clean up
|
||||
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
|
||||
}
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetCurrentCount returns the current concurrency count for debugging/monitoring
|
||||
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetAccountCurrentCount returns current concurrency count for an account
|
||||
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// GetUserCurrentCount returns current concurrency count for a user
|
||||
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
@@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int
|
||||
// Returns true if successful, false if the wait queue is full.
|
||||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
// Redis not available, allow request
|
||||
return true, nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if err != nil {
|
||||
// On error, allow the request to proceed (fail open)
|
||||
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result == 1, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||||
// Should be called when a request completes or exits the wait queue.
|
||||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
// Use background context to ensure decrement even if original context is cancelled
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
|
||||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserWaitCount returns current wait queue count for a user
|
||||
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
|
||||
@@ -4,17 +4,14 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -25,19 +22,11 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
verifyCodeKeyPrefix = "email_verify:"
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// verifyCodeData Redis 中存储的验证码数据
|
||||
type verifyCodeData struct {
|
||||
Code string `json:"code"`
|
||||
Attempts int `json:"attempts"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SmtpConfig SMTP配置
|
||||
type SmtpConfig struct {
|
||||
Host string
|
||||
@@ -51,15 +40,15 @@ type SmtpConfig struct {
|
||||
|
||||
// EmailService 邮件服务
|
||||
type EmailService struct {
|
||||
settingRepo *repository.SettingRepository
|
||||
rdb *redis.Client
|
||||
settingRepo ports.SettingRepository
|
||||
cache ports.EmailCache
|
||||
}
|
||||
|
||||
// NewEmailService 创建邮件服务实例
|
||||
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService {
|
||||
func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService {
|
||||
return &EmailService{
|
||||
settingRepo: settingRepo,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,10 +190,8 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
|
||||
|
||||
// SendVerifyCode 发送验证码邮件
|
||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
|
||||
// 检查是否在冷却期内
|
||||
existing, err := s.getVerifyCodeData(ctx, key)
|
||||
existing, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err == nil && existing != nil {
|
||||
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
||||
return ErrVerifyCodeTooFrequent
|
||||
@@ -218,12 +205,12 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
}
|
||||
|
||||
// 保存验证码到 Redis
|
||||
data := &verifyCodeData{
|
||||
data := &ports.VerificationCodeData{
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
return fmt.Errorf("save verify code: %w", err)
|
||||
}
|
||||
|
||||
@@ -241,9 +228,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
|
||||
// VerifyCode 验证验证码
|
||||
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
|
||||
data, err := s.getVerifyCodeData(ctx, key)
|
||||
data, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err != nil || data == nil {
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
@@ -256,7 +241,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
data.Attempts++
|
||||
_ = s.setVerifyCodeData(ctx, key, data)
|
||||
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
@@ -264,32 +249,10 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
s.rdb.Del(ctx, key)
|
||||
_ = s.cache.DeleteVerificationCode(ctx, email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getVerifyCodeData 从 Redis 获取验证码数据
|
||||
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data verifyCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// setVerifyCodeData 保存验证码数据到 Redis
|
||||
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
|
||||
}
|
||||
|
||||
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
|
||||
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
return fmt.Sprintf(`
|
||||
|
||||
@@ -12,25 +12,27 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ClaudeUpstream handles HTTP requests to Claude API
|
||||
type ClaudeUpstream interface {
|
||||
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||
}
|
||||
|
||||
const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
stickySessionPrefix = "sticky_session:"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
@@ -76,46 +78,48 @@ type ForwardResult struct {
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
repos *repository.Repositories
|
||||
rdb *redis.Client
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.GatewayCache
|
||||
cfg *config.Config
|
||||
oauthService *OAuthService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpClient *http.Client
|
||||
claudeUpstream ClaudeUpstream
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService {
|
||||
// 计算响应头超时时间
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second // 默认5分钟,LLM高负载时可能排队较久
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout, // 等待上游响应头的超时
|
||||
// 注意:不设置整体 Timeout,让流式响应可以无限时间传输
|
||||
}
|
||||
func NewGatewayService(
|
||||
accountRepo ports.AccountRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.GatewayCache,
|
||||
cfg *config.Config,
|
||||
oauthService *OAuthService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
identityService *IdentityService,
|
||||
claudeUpstream ClaudeUpstream,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
repos: repos,
|
||||
rdb: rdb,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
oauthService: oauthService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
// 不设置 Timeout:流式请求可能持续十几分钟
|
||||
// 超时控制由 Transport.ResponseHeaderTimeout 负责(只控制等待响应头)
|
||||
},
|
||||
claudeUpstream: claudeUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,14 +274,14 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
|
||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
s.rdb.Expire(ctx, stickySessionPrefix+sessionHash, stickySessionTTL)
|
||||
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -287,9 +291,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
|
||||
} else {
|
||||
accounts, err = s.repos.Account.ListSchedulable(ctx)
|
||||
accounts, err = s.accountRepo.ListSchedulable(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -327,7 +331,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
s.rdb.Set(ctx, stickySessionPrefix+sessionHash, selected.ID, stickySessionTTL)
|
||||
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
@@ -352,37 +356,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
|
||||
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
|
||||
// 检查是否需要刷新
|
||||
needRefresh := false
|
||||
if expiresAtStr != "" {
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt {
|
||||
needRefresh = true
|
||||
}
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
if needRefresh || accessToken == "" {
|
||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("refresh token failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新账号凭证
|
||||
account.Credentials["access_token"] = tokenInfo.AccessToken
|
||||
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.repos.Account.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo.AccessToken, "oauth", nil
|
||||
}
|
||||
|
||||
// Token刷新由后台 TokenRefreshService 处理,此处只返回当前token
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
|
||||
@@ -418,48 +395,25 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}
|
||||
|
||||
// 构建上游请求
|
||||
upstreamResult, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 选择使用的client:如果有代理则使用独立的client,否则使用共享的httpClient
|
||||
httpClient := s.httpClient
|
||||
if upstreamResult.Client != nil {
|
||||
httpClient = upstreamResult.Client
|
||||
// 获取代理URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := httpClient.Do(upstreamResult.Request)
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理401错误:刷新token重试
|
||||
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
|
||||
resp.Body.Close()
|
||||
token, tokenType, err = s.forceRefreshToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
upstreamResult, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 重试时也需要使用正确的client
|
||||
httpClient = s.httpClient
|
||||
if upstreamResult.Client != nil {
|
||||
httpClient = upstreamResult.Client
|
||||
}
|
||||
resp, err = httpClient.Do(upstreamResult.Request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retry request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
@@ -491,13 +445,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUpstreamRequestResult contains the request and optional custom client for proxy
|
||||
type buildUpstreamRequestResult struct {
|
||||
Request *http.Request
|
||||
Client *http.Client // nil means use default s.httpClient
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
@@ -506,7 +454,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// OAuth账号:应用统一指纹
|
||||
var fingerprint *Fingerprint
|
||||
var fingerprint *ports.Fingerprint
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
@@ -566,48 +514,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
}
|
||||
|
||||
// 配置代理 - 创建独立的client避免并发修改共享httpClient
|
||||
var customClient *http.Client
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
// 计算响应头超时时间(与默认 Transport 保持一致)
|
||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
// 创建独立的client,避免并发时修改共享的s.httpClient.Transport
|
||||
customClient = &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &buildUpstreamRequestResult{
|
||||
Request: req,
|
||||
Client: customClient,
|
||||
}, nil
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// getBetaHeader 处理anthropic-beta header
|
||||
// 对于OAuth账号,需要确保包含oauth-2025-04-20
|
||||
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
|
||||
const oauthBeta = "oauth-2025-04-20"
|
||||
const claudeCodeBeta = "claude-code-20250219"
|
||||
|
||||
// 如果客户端传了anthropic-beta
|
||||
if clientBetaHeader != "" {
|
||||
// 已包含oauth beta则直接返回
|
||||
if strings.Contains(clientBetaHeader, oauthBeta) {
|
||||
if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
|
||||
return clientBetaHeader
|
||||
}
|
||||
|
||||
@@ -620,7 +536,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
// 在claude-code-20250219后面插入oauth beta
|
||||
claudeCodeIdx := -1
|
||||
for i, p := range parts {
|
||||
if p == claudeCodeBeta {
|
||||
if p == claude.BetaClaudeCode {
|
||||
claudeCodeIdx = i
|
||||
break
|
||||
}
|
||||
@@ -630,13 +546,13 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
// 在claude-code后面插入
|
||||
newParts := make([]string, 0, len(parts)+1)
|
||||
newParts = append(newParts, parts[:claudeCodeIdx+1]...)
|
||||
newParts = append(newParts, oauthBeta)
|
||||
newParts = append(newParts, claude.BetaOAuth)
|
||||
newParts = append(newParts, parts[claudeCodeIdx+1:]...)
|
||||
return strings.Join(newParts, ",")
|
||||
}
|
||||
|
||||
// 没有claude-code,放在第一位
|
||||
return oauthBeta + "," + clientBetaHeader
|
||||
return claude.BetaOAuth + "," + clientBetaHeader
|
||||
}
|
||||
|
||||
// 客户端没传,根据模型生成
|
||||
@@ -650,29 +566,10 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
|
||||
// haiku模型不需要claude-code beta
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return "oauth-2025-04-20,interleaved-thinking-2025-05-14"
|
||||
return claude.HaikuBetaHeader
|
||||
}
|
||||
|
||||
return "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
|
||||
}
|
||||
|
||||
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
account.Credentials["access_token"] = tokenInfo.AccessToken
|
||||
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.repos.Account.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo.AccessToken, "oauth", nil
|
||||
return claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
@@ -1000,7 +897,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil {
|
||||
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1008,7 +905,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if cost.TotalCost > 0 {
|
||||
if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
@@ -1023,7 +920,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if cost.ActualCost > 0 {
|
||||
if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
@@ -1038,9 +935,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
// 更新账号最后使用时间
|
||||
if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil {
|
||||
if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil {
|
||||
log.Printf("Update last used failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
|
||||
mappedModel := account.GetMappedModel(req.Model)
|
||||
if mappedModel != req.Model {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
|
||||
return err
|
||||
}
|
||||
|
||||
// 构建上游请求
|
||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取代理URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
// 标记账号状态(429/529等)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
// 返回简化的错误响应
|
||||
errMsg := "Upstream request failed"
|
||||
switch resp.StatusCode {
|
||||
case 429:
|
||||
errMsg = "Rate limit exceeded"
|
||||
case 529:
|
||||
errMsg = "Service overloaded"
|
||||
}
|
||||
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
|
||||
return fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 透传成功响应
|
||||
c.Data(resp.StatusCode, "application/json", respBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil {
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
if accountUUID != "" && fp.ClientID != "" {
|
||||
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置认证头
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
|
||||
// 白名单透传 headers
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if allowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth 账号:应用指纹到请求头
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if fp != nil {
|
||||
s.identityService.ApplyFingerprint(req, fp)
|
||||
}
|
||||
}
|
||||
|
||||
// 确保必要的 headers 存在
|
||||
if req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// countTokensError 返回 count_tokens 错误响应
|
||||
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -34,11 +35,11 @@ type UpdateGroupRequest struct {
|
||||
|
||||
// GroupService 分组管理服务
|
||||
type GroupService struct {
|
||||
groupRepo *repository.GroupRepository
|
||||
groupRepo ports.GroupRepository
|
||||
}
|
||||
|
||||
// NewGroupService 创建分组服务实例
|
||||
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService {
|
||||
func NewGroupService(groupRepo ports.GroupRepository) *GroupService {
|
||||
return &GroupService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
@@ -84,7 +85,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
|
||||
}
|
||||
|
||||
// List 获取分组列表
|
||||
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) {
|
||||
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
groups, pagination, err := s.groupRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list groups: %w", err)
|
||||
|
||||
@@ -11,15 +11,10 @@ import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefix
|
||||
identityFingerprintKey = "identity:fingerprint:"
|
||||
)
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
@@ -29,20 +24,8 @@ var (
|
||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||
)
|
||||
|
||||
// Fingerprint 存储的指纹数据结构
|
||||
type Fingerprint struct {
|
||||
ClientID string `json:"client_id"` // 64位hex客户端ID(首次随机生成)
|
||||
UserAgent string `json:"user_agent"` // User-Agent
|
||||
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
|
||||
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
|
||||
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
|
||||
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
|
||||
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
|
||||
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
|
||||
}
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
var defaultFingerprint = ports.Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
@@ -54,39 +37,31 @@ var defaultFingerprint = Fingerprint{
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
type IdentityService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.IdentityCache
|
||||
}
|
||||
|
||||
// NewIdentityService 创建新的IdentityService
|
||||
func NewIdentityService(rdb *redis.Client) *IdentityService {
|
||||
return &IdentityService{rdb: rdb}
|
||||
func NewIdentityService(cache ports.IdentityCache) *IdentityService {
|
||||
return &IdentityService{cache: cache}
|
||||
}
|
||||
|
||||
// GetOrCreateFingerprint 获取或创建账号的指纹
|
||||
// 如果缓存存在,检测user-agent版本,新版本则更新
|
||||
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
|
||||
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
|
||||
|
||||
// 尝试从Redis获取缓存的指纹
|
||||
data, err := s.rdb.Get(ctx, key).Bytes()
|
||||
if err == nil && len(data) > 0 {
|
||||
// 缓存存在,解析指纹
|
||||
var cached Fingerprint
|
||||
if err := json.Unmarshal(data, &cached); err == nil {
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
if newData, err := json.Marshal(cached); err == nil {
|
||||
s.rdb.Set(ctx, key, newData, 0) // 永不过期
|
||||
}
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return &cached, nil
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) {
|
||||
// 尝试从缓存获取指纹
|
||||
cached, err := s.cache.GetFingerprint(ctx, accountID)
|
||||
if err == nil && cached != nil {
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
_ = s.cache.SetFingerprint(ctx, accountID, cached)
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 缓存不存在或解析失败,创建新指纹
|
||||
@@ -95,11 +70,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
// 生成随机ClientID
|
||||
fp.ClientID = generateClientID()
|
||||
|
||||
// 保存到Redis(永不过期)
|
||||
if data, err := json.Marshal(fp); err == nil {
|
||||
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
// 保存到缓存(永不过期)
|
||||
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
|
||||
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
||||
@@ -107,8 +80,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// createFingerprintFromHeaders 从请求头创建指纹
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
|
||||
fp := &Fingerprint{}
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint {
|
||||
fp := &ports.Fingerprint{}
|
||||
|
||||
// 获取User-Agent
|
||||
if ua := headers.Get("User-Agent"); ua != "" {
|
||||
@@ -137,7 +110,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
|
||||
}
|
||||
|
||||
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) {
|
||||
if fp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,32 +2,36 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||
type ClaudeOAuthClient interface {
|
||||
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||
ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||
}
|
||||
|
||||
// OAuthService handles OAuth authentication flows
|
||||
type OAuthService struct {
|
||||
sessionStore *oauth.SessionStore
|
||||
proxyRepo *repository.ProxyRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
oauthClient ClaudeOAuthClient
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService {
|
||||
func NewOAuthService(proxyRepo ports.ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
|
||||
return &OAuthService{
|
||||
sessionStore: oauth.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
oauthClient: oauthClient,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,177 +214,21 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
|
||||
|
||||
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
|
||||
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
targetURL := "https://claude.ai/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetSuccessResult(&orgs).
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if len(orgs) == 0 {
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
return orgs[0].UUID, nil
|
||||
return s.oauthClient.GetOrganizationUUID(ctx, sessionKey, proxyURL)
|
||||
}
|
||||
|
||||
// getAuthorizationCode gets the authorization code using sessionKey
|
||||
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
|
||||
// Build request body - must include organization_uuid as per CRS
|
||||
reqBody := map[string]interface{}{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID, // Required field!
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
// Response contains redirect_uri with code, not direct code field
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||
SetHeader("Cache-Control", "no-cache").
|
||||
SetHeader("Origin", "https://claude.ai").
|
||||
SetHeader("Referer", "https://claude.ai/new").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&result).
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if result.RedirectURI == "" {
|
||||
return "", fmt.Errorf("no redirect_uri in response")
|
||||
}
|
||||
|
||||
// Parse redirect_uri to extract code and state
|
||||
parsedURL, err := url.Parse(result.RedirectURI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||
}
|
||||
|
||||
queryParams := parsedURL.Query()
|
||||
authCode := queryParams.Get("code")
|
||||
responseState := queryParams.Get("state")
|
||||
|
||||
if authCode == "" {
|
||||
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||
}
|
||||
|
||||
// Combine code with state if present (as CRS does)
|
||||
fullCode := authCode
|
||||
if responseState != "" {
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||
return fullCode, nil
|
||||
return s.oauthClient.GetAuthorizationCode(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges authorization code for tokens
|
||||
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
// Parse code#state format if present
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if parts := strings.Split(code, "#"); len(parts) > 1 {
|
||||
authCode = parts[0]
|
||||
codeState = parts[1]
|
||||
}
|
||||
|
||||
// Build JSON body as CRS does (not form data!)
|
||||
reqBody := map[string]interface{}{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"code_verifier": codeVerifier,
|
||||
}
|
||||
|
||||
// Add state if present
|
||||
if codeState != "" {
|
||||
reqBody["state"] = codeState
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
@@ -390,7 +238,6 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
||||
Scope: tokenResp.Scope,
|
||||
}
|
||||
|
||||
// Extract org_uuid and account_uuid from response
|
||||
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
|
||||
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
||||
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
||||
@@ -405,27 +252,9 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
||||
|
||||
// RefreshToken refreshes an OAuth token
|
||||
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauth.ClientID)
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenInfo{
|
||||
@@ -455,17 +284,3 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// createReqClient creates a req client with Chrome impersonation and optional proxy
|
||||
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
// Set proxy if specified
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
35
backend/internal/service/ports/account.go
Normal file
35
backend/internal/service/ports/account.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *model.Account) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Account, error)
|
||||
Update(ctx context.Context, account *model.Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListActive(ctx context.Context) ([]model.Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]model.Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
}
|
||||
24
backend/internal/service/ports/api_key.go
Normal file
24
backend/internal/service/ports/api_key.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *model.ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
|
||||
Update(ctx context.Context, key *model.ApiKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
16
backend/internal/service/ports/api_key_cache.go
Normal file
16
backend/internal/service/ports/api_key_cache.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ApiKeyCache defines cache operations for API key service
|
||||
type ApiKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
}
|
||||
31
backend/internal/service/ports/billing_cache.go
Normal file
31
backend/internal/service/ports/billing_cache.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SubscriptionCacheData represents cached subscription data
|
||||
type SubscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
GetUserBalance(ctx context.Context, userID int64) (float64, error)
|
||||
SetUserBalance(ctx context.Context, userID int64, balance float64) error
|
||||
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
|
||||
InvalidateUserBalance(ctx context.Context, userID int64) error
|
||||
|
||||
// Subscription operations
|
||||
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
|
||||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||
}
|
||||
19
backend/internal/service/ports/concurrency_cache.go
Normal file
19
backend/internal/service/ports/concurrency_cache.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package ports
|
||||
|
||||
import "context"
|
||||
|
||||
// ConcurrencyCache defines cache operations for concurrency service
|
||||
type ConcurrencyCache interface {
|
||||
// Slot management
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// Wait queue
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
}
|
||||
20
backend/internal/service/ports/email_cache.go
Normal file
20
backend/internal/service/ports/email_cache.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// VerificationCodeData represents verification code data
|
||||
type VerificationCodeData struct {
|
||||
Code string
|
||||
Attempts int
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// EmailCache defines cache operations for email service
|
||||
type EmailCache interface {
|
||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||
DeleteVerificationCode(ctx context.Context, email string) error
|
||||
}
|
||||
13
backend/internal/service/ports/gateway_cache.go
Normal file
13
backend/internal/service/ports/gateway_cache.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
type GatewayCache interface {
|
||||
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
|
||||
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||
}
|
||||
28
backend/internal/service/ports/group.go
Normal file
28
backend/internal/service/ports/group.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *model.Group) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Group, error)
|
||||
Update(ctx context.Context, group *model.Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
|
||||
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
|
||||
DB() *gorm.DB
|
||||
}
|
||||
21
backend/internal/service/ports/identity_cache.go
Normal file
21
backend/internal/service/ports/identity_cache.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package ports
|
||||
|
||||
import "context"
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
type Fingerprint struct {
|
||||
ClientID string
|
||||
UserAgent string
|
||||
StainlessLang string
|
||||
StainlessPackageVersion string
|
||||
StainlessOS string
|
||||
StainlessArch string
|
||||
StainlessRuntime string
|
||||
StainlessRuntimeVersion string
|
||||
}
|
||||
|
||||
// IdentityCache defines cache operations for identity service
|
||||
type IdentityCache interface {
|
||||
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||
}
|
||||
23
backend/internal/service/ports/proxy.go
Normal file
23
backend/internal/service/ports/proxy.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type ProxyRepository interface {
|
||||
Create(ctx context.Context, proxy *model.Proxy) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
|
||||
Update(ctx context.Context, proxy *model.Proxy) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Proxy, error)
|
||||
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
|
||||
|
||||
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
|
||||
}
|
||||
15
backend/internal/service/ports/redeem_cache.go
Normal file
15
backend/internal/service/ports/redeem_cache.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedeemCache defines cache operations for redeem service
|
||||
type RedeemCache interface {
|
||||
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementRedeemAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
|
||||
ReleaseRedeemLock(ctx context.Context, code string) error
|
||||
}
|
||||
22
backend/internal/service/ports/redeem_code.go
Normal file
22
backend/internal/service/ports/redeem_code.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type RedeemCodeRepository interface {
|
||||
Create(ctx context.Context, code *model.RedeemCode) error
|
||||
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
|
||||
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
|
||||
Update(ctx context.Context, code *model.RedeemCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
Use(ctx context.Context, id, userID int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
|
||||
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
|
||||
}
|
||||
17
backend/internal/service/ports/setting.go
Normal file
17
backend/internal/service/ports/setting.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
)
|
||||
|
||||
type SettingRepository interface {
|
||||
Get(ctx context.Context, key string) (*model.Setting, error)
|
||||
GetValue(ctx context.Context, key string) (string, error)
|
||||
Set(ctx context.Context, key, value string) error
|
||||
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
|
||||
SetMultiple(ctx context.Context, settings map[string]string) error
|
||||
GetAll(ctx context.Context) (map[string]string, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
}
|
||||
12
backend/internal/service/ports/update_cache.go
Normal file
12
backend/internal/service/ports/update_cache.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UpdateCache defines cache operations for update service
|
||||
type UpdateCache interface {
|
||||
GetUpdateInfo(ctx context.Context) (string, error)
|
||||
SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error
|
||||
}
|
||||
28
backend/internal/service/ports/usage_log.go
Normal file
28
backend/internal/service/ports/usage_log.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
Create(ctx context.Context, log *model.UsageLog) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
|
||||
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
|
||||
}
|
||||
25
backend/internal/service/ports/user.go
Normal file
25
backend/internal/service/ports/user.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user *model.User) error
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
GetByEmail(ctx context.Context, email string) (*model.User, error)
|
||||
Update(ctx context.Context, user *model.User) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
|
||||
|
||||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||||
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
36
backend/internal/service/ports/user_subscription.go
Normal file
36
backend/internal/service/ports/user_subscription.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type UserSubscriptionRepository interface {
|
||||
Create(ctx context.Context, sub *model.UserSubscription) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
|
||||
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
Update(ctx context.Context, sub *model.UserSubscription) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
|
||||
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
|
||||
|
||||
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
|
||||
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
|
||||
UpdateStatus(ctx context.Context, subscriptionID int64, status string) error
|
||||
UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error
|
||||
|
||||
ActivateWindows(ctx context.Context, id int64, start time.Time) error
|
||||
ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
IncrementUsage(ctx context.Context, id int64, costUSD float64) error
|
||||
|
||||
BatchUpdateExpiredStatus(ctx context.Context) (int64, error)
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -20,13 +19,19 @@ import (
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
}
|
||||
|
||||
// PricingRemoteClient 远程价格数据获取接口
|
||||
type PricingRemoteClient interface {
|
||||
FetchPricingJSON(ctx context.Context, url string) ([]byte, error)
|
||||
FetchHashText(ctx context.Context, url string) (string, error)
|
||||
}
|
||||
|
||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||
@@ -42,11 +47,12 @@ type LiteLLMRawEntry struct {
|
||||
|
||||
// PricingService 动态价格服务
|
||||
type PricingService struct {
|
||||
cfg *config.Config
|
||||
mu sync.RWMutex
|
||||
pricingData map[string]*LiteLLMModelPricing
|
||||
lastUpdated time.Time
|
||||
localHash string
|
||||
cfg *config.Config
|
||||
remoteClient PricingRemoteClient
|
||||
mu sync.RWMutex
|
||||
pricingData map[string]*LiteLLMModelPricing
|
||||
lastUpdated time.Time
|
||||
localHash string
|
||||
|
||||
// 停止信号
|
||||
stopCh chan struct{}
|
||||
@@ -54,11 +60,12 @@ type PricingService struct {
|
||||
}
|
||||
|
||||
// NewPricingService 创建价格服务
|
||||
func NewPricingService(cfg *config.Config) *PricingService {
|
||||
func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *PricingService {
|
||||
s := &PricingService{
|
||||
cfg: cfg,
|
||||
pricingData: make(map[string]*LiteLLMModelPricing),
|
||||
stopCh: make(chan struct{}),
|
||||
cfg: cfg,
|
||||
remoteClient: remoteClient,
|
||||
pricingData: make(map[string]*LiteLLMModelPricing),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -199,21 +206,13 @@ func (s *PricingService) syncWithRemote() error {
|
||||
func (s *PricingService) downloadPricingData() error {
|
||||
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Get(s.cfg.Pricing.RemoteURL)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read response failed: %w", err)
|
||||
}
|
||||
|
||||
// 解析JSON数据(使用灵活的解析方式)
|
||||
data, err := s.parsePricingData(body)
|
||||
@@ -367,29 +366,10 @@ func (s *PricingService) useFallbackPricing() error {
|
||||
|
||||
// fetchRemoteHash 从远程获取哈希值
|
||||
func (s *PricingService) fetchRemoteHash() (string, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Get(s.cfg.Pricing.HashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 哈希文件格式:hash filename 或者纯 hash
|
||||
hash := strings.TrimSpace(string(body))
|
||||
parts := strings.Fields(hash)
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
return hash, nil
|
||||
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
|
||||
}
|
||||
|
||||
// computeFileHash 计算文件哈希
|
||||
@@ -466,14 +446,14 @@ func (s *PricingService) extractBaseName(model string) string {
|
||||
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// Claude模型系列匹配规则
|
||||
familyPatterns := map[string][]string{
|
||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
|
||||
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
|
||||
"sonnet-3": {"claude-3-sonnet"},
|
||||
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
|
||||
"haiku-3": {"claude-3-haiku"},
|
||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
|
||||
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
|
||||
"sonnet-3": {"claude-3-sonnet"},
|
||||
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
|
||||
"haiku-3": {"claude-3-haiku"},
|
||||
}
|
||||
|
||||
// 确定模型属于哪个系列
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -37,11 +38,11 @@ type UpdateProxyRequest struct {
|
||||
|
||||
// ProxyService 代理管理服务
|
||||
type ProxyService struct {
|
||||
proxyRepo *repository.ProxyRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
}
|
||||
|
||||
// NewProxyService 创建代理服务实例
|
||||
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
|
||||
func NewProxyService(proxyRepo ports.ProxyRepository) *ProxyService {
|
||||
return &ProxyService{
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
@@ -80,7 +81,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
|
||||
}
|
||||
|
||||
// List 获取代理列表
|
||||
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
|
||||
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
proxies, pagination, err := s.proxyRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list proxies: %w", err)
|
||||
|
||||
@@ -9,20 +9,20 @@ import (
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
type RateLimitService struct {
|
||||
repos *repository.Repositories
|
||||
cfg *config.Config
|
||||
accountRepo ports.AccountRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewRateLimitService 创建RateLimitService实例
|
||||
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
|
||||
func NewRateLimitService(accountRepo ports.AccountRepository, cfg *config.Config) *RateLimitService {
|
||||
return &RateLimitService{
|
||||
repos: repos,
|
||||
cfg: cfg,
|
||||
accountRepo: accountRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
|
||||
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
if resetTimestamp == "" {
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
return
|
||||
@@ -88,7 +88,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
if err != nil {
|
||||
log.Printf("Parse reset timestamp failed: %v", err)
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
return
|
||||
@@ -97,7 +97,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
resetAt := time.Unix(ts, 0)
|
||||
|
||||
// 标记限流状态
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
// 根据重置时间反推5h窗口
|
||||
windowEnd := resetAt
|
||||
windowStart := resetAt.Add(-5 * time.Hour)
|
||||
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
|
||||
}
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -152,13 +152,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
|
||||
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
|
||||
}
|
||||
|
||||
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -166,5 +166,5 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
|
||||
return s.repos.Account.ClearRateLimit(ctx, accountID)
|
||||
return s.accountRepo.ClearRateLimit(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -25,11 +26,9 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
|
||||
redeemLockKeyPrefix = "redeem:lock:"
|
||||
redeemMaxErrorsPerHour = 20
|
||||
redeemRateLimitDuration = time.Hour
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
redeemMaxErrorsPerHour = 20
|
||||
redeemRateLimitDuration = time.Hour
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
)
|
||||
|
||||
// GenerateCodesRequest 生成兑换码请求
|
||||
@@ -49,28 +48,30 @@ type RedeemCodeResponse struct {
|
||||
|
||||
// RedeemService 兑换码服务
|
||||
type RedeemService struct {
|
||||
redeemRepo *repository.RedeemCodeRepository
|
||||
userRepo *repository.UserRepository
|
||||
redeemRepo ports.RedeemCodeRepository
|
||||
userRepo ports.UserRepository
|
||||
subscriptionService *SubscriptionService
|
||||
rdb *redis.Client
|
||||
cache ports.RedeemCache
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewRedeemService 创建兑换码服务实例
|
||||
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService {
|
||||
func NewRedeemService(
|
||||
redeemRepo ports.RedeemCodeRepository,
|
||||
userRepo ports.UserRepository,
|
||||
subscriptionService *SubscriptionService,
|
||||
cache ports.RedeemCache,
|
||||
billingCacheService *BillingCacheService,
|
||||
) *RedeemService {
|
||||
return &RedeemService{
|
||||
redeemRepo: redeemRepo,
|
||||
userRepo: userRepo,
|
||||
subscriptionService: subscriptionService,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
|
||||
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
|
||||
s.billingCacheService = billingCacheService
|
||||
}
|
||||
|
||||
// GenerateRandomCode 生成随机兑换码
|
||||
func (s *RedeemService) GenerateRandomCode() (string, error) {
|
||||
// 生成16字节随机数据
|
||||
@@ -137,13 +138,11 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
|
||||
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
|
||||
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
@@ -158,27 +157,21 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
|
||||
|
||||
// incrementRedeemErrorCount 增加用户兑换错误计数
|
||||
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
_ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
|
||||
}
|
||||
|
||||
// acquireRedeemLock 尝试获取兑换码的分布式锁
|
||||
// 返回 true 表示获取成功,false 表示锁已被占用
|
||||
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return true // 无 Redis 时降级为不加锁
|
||||
}
|
||||
|
||||
key := redeemLockKeyPrefix + code
|
||||
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
|
||||
ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
|
||||
if err != nil {
|
||||
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
|
||||
return true
|
||||
@@ -188,12 +181,11 @@ func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool
|
||||
|
||||
// releaseRedeemLock 释放兑换码的分布式锁
|
||||
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := redeemLockKeyPrefix + code
|
||||
s.rdb.Del(ctx, key)
|
||||
_ = s.cache.ReleaseRedeemLock(ctx, code)
|
||||
}
|
||||
|
||||
// Redeem 使用兑换码
|
||||
@@ -335,7 +327,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
|
||||
}
|
||||
|
||||
// List 获取兑换码列表(管理员功能)
|
||||
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
|
||||
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
codes, pagination, err := s.redeemRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Services 服务集合容器
|
||||
type Services struct {
|
||||
Auth *AuthService
|
||||
@@ -33,107 +26,6 @@ type Services struct {
|
||||
Subscription *SubscriptionService
|
||||
Concurrency *ConcurrencyService
|
||||
Identity *IdentityService
|
||||
}
|
||||
|
||||
// NewServices 创建所有服务实例
|
||||
func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services {
|
||||
// 初始化价格服务
|
||||
pricingService := NewPricingService(cfg)
|
||||
if err := pricingService.Initialize(); err != nil {
|
||||
// 价格服务初始化失败不应阻止启动,使用回退价格
|
||||
println("[Service] Warning: Pricing service initialization failed:", err.Error())
|
||||
}
|
||||
|
||||
// 初始化计费服务(依赖价格服务)
|
||||
billingService := NewBillingService(cfg, pricingService)
|
||||
|
||||
// 初始化其他服务
|
||||
authService := NewAuthService(repos.User, cfg)
|
||||
userService := NewUserService(repos.User, cfg)
|
||||
apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg)
|
||||
groupService := NewGroupService(repos.Group)
|
||||
accountService := NewAccountService(repos.Account, repos.Group)
|
||||
proxyService := NewProxyService(repos.Proxy)
|
||||
usageService := NewUsageService(repos.UsageLog, repos.User)
|
||||
|
||||
// 初始化订阅服务 (RedeemService 依赖)
|
||||
subscriptionService := NewSubscriptionService(repos)
|
||||
|
||||
// 初始化兑换服务 (依赖订阅服务)
|
||||
redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb)
|
||||
|
||||
// 初始化Admin服务
|
||||
adminService := NewAdminService(repos)
|
||||
|
||||
// 初始化OAuth服务(GatewayService依赖)
|
||||
oauthService := NewOAuthService(repos.Proxy)
|
||||
|
||||
// 初始化限流服务
|
||||
rateLimitService := NewRateLimitService(repos, cfg)
|
||||
|
||||
// 初始化计费缓存服务
|
||||
billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription)
|
||||
|
||||
// 初始化账号使用量服务
|
||||
accountUsageService := NewAccountUsageService(repos, oauthService)
|
||||
|
||||
// 初始化账号测试服务
|
||||
accountTestService := NewAccountTestService(repos, oauthService)
|
||||
|
||||
// 初始化身份指纹服务
|
||||
identityService := NewIdentityService(rdb)
|
||||
|
||||
// 初始化Gateway服务
|
||||
gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService)
|
||||
|
||||
// 初始化设置服务
|
||||
settingService := NewSettingService(repos.Setting, cfg)
|
||||
emailService := NewEmailService(repos.Setting, rdb)
|
||||
|
||||
// 初始化邮件队列服务
|
||||
emailQueueService := NewEmailQueueService(emailService, 3)
|
||||
|
||||
// 初始化Turnstile服务
|
||||
turnstileService := NewTurnstileService(settingService)
|
||||
|
||||
// 设置Auth服务的依赖(用于注册开关和邮件验证)
|
||||
authService.SetSettingService(settingService)
|
||||
authService.SetEmailService(emailService)
|
||||
authService.SetTurnstileService(turnstileService)
|
||||
authService.SetEmailQueueService(emailQueueService)
|
||||
|
||||
// 初始化并发控制服务
|
||||
concurrencyService := NewConcurrencyService(rdb)
|
||||
|
||||
// 注入计费缓存服务到需要失效缓存的服务
|
||||
redeemService.SetBillingCacheService(billingCacheService)
|
||||
subscriptionService.SetBillingCacheService(billingCacheService)
|
||||
SetAdminServiceBillingCache(adminService, billingCacheService)
|
||||
|
||||
return &Services{
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OAuth: oauthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
}
|
||||
Update *UpdateService
|
||||
TokenRefresh *TokenRefreshService
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"strconv"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -18,12 +18,12 @@ var (
|
||||
|
||||
// SettingService 系统设置服务
|
||||
type SettingService struct {
|
||||
settingRepo *repository.SettingRepository
|
||||
settingRepo ports.SettingRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
|
||||
func NewSettingService(settingRepo ports.SettingRepository, cfg *config.Config) *SettingService {
|
||||
return &SettingService{
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
|
||||
@@ -7,7 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,18 +24,18 @@ var (
|
||||
|
||||
// SubscriptionService 订阅服务
|
||||
type SubscriptionService struct {
|
||||
repos *repository.Repositories
|
||||
groupRepo ports.GroupRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewSubscriptionService 创建订阅服务
|
||||
func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService {
|
||||
return &SubscriptionService{repos: repos}
|
||||
}
|
||||
|
||||
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
|
||||
func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) {
|
||||
s.billingCacheService = billingCacheService
|
||||
func NewSubscriptionService(groupRepo ports.GroupRepository, userSubRepo ports.UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
|
||||
return &SubscriptionService{
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
|
||||
// AssignSubscriptionInput 分配订阅输入
|
||||
@@ -49,7 +50,7 @@ type AssignSubscriptionInput struct {
|
||||
// AssignSubscription 分配订阅给用户(不允许重复分配)
|
||||
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
@@ -58,7 +59,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
}
|
||||
|
||||
// 检查是否已存在订阅
|
||||
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -88,10 +89,11 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
// 如果用户已有同分组的订阅:
|
||||
// - 未过期:从当前过期时间累加天数
|
||||
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
|
||||
//
|
||||
// 如果没有订阅:创建新订阅
|
||||
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
@@ -100,7 +102,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 查询是否已有订阅
|
||||
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
// 不存在记录是正常情况,其他错误需要返回
|
||||
existingSub = nil
|
||||
@@ -125,13 +127,13 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 更新过期时间
|
||||
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
return nil, false, fmt.Errorf("extend subscription: %w", err)
|
||||
}
|
||||
|
||||
// 如果订阅已过期或被暂停,恢复为active状态
|
||||
if existingSub.Status != model.SubscriptionStatusActive {
|
||||
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
|
||||
return nil, false, fmt.Errorf("update subscription status: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -143,7 +145,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
newNotes += "\n"
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
// 备注更新失败不影响主流程
|
||||
}
|
||||
}
|
||||
@@ -159,7 +161,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 返回更新后的订阅
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID)
|
||||
return sub, true, err // true 表示是续期
|
||||
}
|
||||
|
||||
@@ -191,27 +193,27 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass
|
||||
|
||||
now := time.Now()
|
||||
sub := &model.UserSubscription{
|
||||
UserID: input.UserID,
|
||||
GroupID: input.GroupID,
|
||||
StartsAt: now,
|
||||
ExpiresAt: now.AddDate(0, 0, validityDays),
|
||||
Status: model.SubscriptionStatusActive,
|
||||
UserID: input.UserID,
|
||||
GroupID: input.GroupID,
|
||||
StartsAt: now,
|
||||
ExpiresAt: now.AddDate(0, 0, validityDays),
|
||||
Status: model.SubscriptionStatusActive,
|
||||
AssignedAt: now,
|
||||
Notes: input.Notes,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Notes: input.Notes,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
// 只有当 AssignedBy > 0 时才设置(0 表示系统分配,如兑换码)
|
||||
if input.AssignedBy > 0 {
|
||||
sub.AssignedBy = &input.AssignedBy
|
||||
}
|
||||
|
||||
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
|
||||
if err := s.userSubRepo.Create(ctx, sub); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 重新获取完整订阅信息(包含关联)
|
||||
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
|
||||
return s.userSubRepo.GetByID(ctx, sub.ID)
|
||||
}
|
||||
|
||||
// BulkAssignSubscriptionInput 批量分配订阅输入
|
||||
@@ -225,17 +227,17 @@ type BulkAssignSubscriptionInput struct {
|
||||
|
||||
// BulkAssignResult 批量分配结果
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int
|
||||
FailedCount int
|
||||
SuccessCount int
|
||||
FailedCount int
|
||||
Subscriptions []model.UserSubscription
|
||||
Errors []string
|
||||
Errors []string
|
||||
}
|
||||
|
||||
// BulkAssignSubscription 批量分配订阅
|
||||
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
|
||||
result := &BulkAssignResult{
|
||||
Subscriptions: make([]model.UserSubscription, 0),
|
||||
Errors: make([]string, 0),
|
||||
Errors: make([]string, 0),
|
||||
}
|
||||
|
||||
for _, userID := range input.UserIDs {
|
||||
@@ -261,12 +263,12 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
|
||||
// RevokeSubscription 撤销订阅
|
||||
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
|
||||
// 先获取订阅信息用于失效缓存
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
|
||||
if err := s.userSubRepo.Delete(ctx, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -285,20 +287,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
|
||||
// ExtendSubscription 延长订阅
|
||||
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
// 计算新的过期时间
|
||||
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
|
||||
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果订阅已过期,恢复为active状态
|
||||
if sub.Status == model.SubscriptionStatusExpired {
|
||||
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -313,17 +315,17 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
}()
|
||||
}
|
||||
|
||||
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
return s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.GetByID(ctx, id)
|
||||
return s.userSubRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetActiveSubscription 获取用户对特定分组的有效订阅
|
||||
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
@@ -332,24 +334,29 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
|
||||
|
||||
// ListUserSubscriptions 获取用户的所有订阅
|
||||
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.ListByUserID(ctx, userID)
|
||||
return s.userSubRepo.ListByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// ListActiveUserSubscriptions 获取用户的所有有效订阅
|
||||
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
|
||||
return s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// ListGroupSubscriptions 获取分组的所有订阅
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.userSubRepo.ListByGroupID(ctx, groupID, params)
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.userSubRepo.List(ctx, params, userID, groupID, status)
|
||||
}
|
||||
|
||||
// startOfDay 返回给定时间所在日期的零点(保持原时区)
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
|
||||
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
|
||||
@@ -358,39 +365,50 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
|
||||
// 使用当天零点作为窗口起始时间
|
||||
windowStart := startOfDay(time.Now())
|
||||
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
|
||||
}
|
||||
|
||||
// CheckAndResetWindows 检查并重置过期的窗口
|
||||
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
|
||||
now := time.Now()
|
||||
// 使用当天零点作为新窗口起始时间
|
||||
windowStart := startOfDay(time.Now())
|
||||
needsInvalidateCache := false
|
||||
|
||||
// 日窗口重置(24小时)
|
||||
if sub.NeedsDailyReset() {
|
||||
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.DailyWindowStart = &now
|
||||
sub.DailyWindowStart = &windowStart
|
||||
sub.DailyUsageUSD = 0
|
||||
needsInvalidateCache = true
|
||||
}
|
||||
|
||||
// 周窗口重置(7天)
|
||||
if sub.NeedsWeeklyReset() {
|
||||
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.WeeklyWindowStart = &now
|
||||
sub.WeeklyWindowStart = &windowStart
|
||||
sub.WeeklyUsageUSD = 0
|
||||
needsInvalidateCache = true
|
||||
}
|
||||
|
||||
// 月窗口重置(30天)
|
||||
if sub.NeedsMonthlyReset() {
|
||||
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.MonthlyWindowStart = &now
|
||||
sub.MonthlyWindowStart = &windowStart
|
||||
sub.MonthlyUsageUSD = 0
|
||||
needsInvalidateCache = true
|
||||
}
|
||||
|
||||
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
|
||||
if needsInvalidateCache && s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -412,15 +430,15 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.U
|
||||
|
||||
// RecordUsage 记录使用量到订阅
|
||||
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
|
||||
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
}
|
||||
|
||||
// SubscriptionProgress 订阅进度
|
||||
type SubscriptionProgress struct {
|
||||
ID int64 `json:"id"`
|
||||
GroupName string `json:"group_name"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
ExpiresInDays int `json:"expires_in_days"`
|
||||
ID int64 `json:"id"`
|
||||
GroupName string `json:"group_name"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
ExpiresInDays int `json:"expires_in_days"`
|
||||
Daily *UsageWindowProgress `json:"daily,omitempty"`
|
||||
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
|
||||
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
|
||||
@@ -428,25 +446,25 @@ type SubscriptionProgress struct {
|
||||
|
||||
// UsageWindowProgress 使用窗口进度
|
||||
type UsageWindowProgress struct {
|
||||
LimitUSD float64 `json:"limit_usd"`
|
||||
UsedUSD float64 `json:"used_usd"`
|
||||
RemainingUSD float64 `json:"remaining_usd"`
|
||||
Percentage float64 `json:"percentage"`
|
||||
WindowStart time.Time `json:"window_start"`
|
||||
ResetsAt time.Time `json:"resets_at"`
|
||||
ResetsInSeconds int64 `json:"resets_in_seconds"`
|
||||
LimitUSD float64 `json:"limit_usd"`
|
||||
UsedUSD float64 `json:"used_usd"`
|
||||
RemainingUSD float64 `json:"remaining_usd"`
|
||||
Percentage float64 `json:"percentage"`
|
||||
WindowStart time.Time `json:"window_start"`
|
||||
ResetsAt time.Time `json:"resets_at"`
|
||||
ResetsInSeconds int64 `json:"resets_in_seconds"`
|
||||
}
|
||||
|
||||
// GetSubscriptionProgress 获取订阅使用进度
|
||||
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
group := sub.Group
|
||||
if group == nil {
|
||||
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
|
||||
group, err = s.groupRepo.GetByID(ctx, sub.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -464,12 +482,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
limit := *group.DailyLimitUSD
|
||||
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
|
||||
progress.Daily = &UsageWindowProgress{
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.DailyUsageUSD,
|
||||
RemainingUSD: limit - sub.DailyUsageUSD,
|
||||
Percentage: (sub.DailyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.DailyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.DailyUsageUSD,
|
||||
RemainingUSD: limit - sub.DailyUsageUSD,
|
||||
Percentage: (sub.DailyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.DailyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
|
||||
}
|
||||
if progress.Daily.RemainingUSD < 0 {
|
||||
@@ -488,12 +506,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
limit := *group.WeeklyLimitUSD
|
||||
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
|
||||
progress.Weekly = &UsageWindowProgress{
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.WeeklyUsageUSD,
|
||||
RemainingUSD: limit - sub.WeeklyUsageUSD,
|
||||
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.WeeklyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.WeeklyUsageUSD,
|
||||
RemainingUSD: limit - sub.WeeklyUsageUSD,
|
||||
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.WeeklyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
|
||||
}
|
||||
if progress.Weekly.RemainingUSD < 0 {
|
||||
@@ -512,12 +530,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
limit := *group.MonthlyLimitUSD
|
||||
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
|
||||
progress.Monthly = &UsageWindowProgress{
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.MonthlyUsageUSD,
|
||||
RemainingUSD: limit - sub.MonthlyUsageUSD,
|
||||
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.MonthlyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
LimitUSD: limit,
|
||||
UsedUSD: sub.MonthlyUsageUSD,
|
||||
RemainingUSD: limit - sub.MonthlyUsageUSD,
|
||||
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
|
||||
WindowStart: *sub.MonthlyWindowStart,
|
||||
ResetsAt: resetsAt,
|
||||
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
|
||||
}
|
||||
if progress.Monthly.RemainingUSD < 0 {
|
||||
@@ -536,7 +554,7 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
|
||||
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
|
||||
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
|
||||
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -555,7 +573,7 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
|
||||
|
||||
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
|
||||
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
|
||||
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
|
||||
return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
|
||||
}
|
||||
|
||||
// ValidateSubscription 验证订阅是否有效
|
||||
@@ -568,7 +586,7 @@ func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *mod
|
||||
}
|
||||
if sub.IsExpired() {
|
||||
// 更新状态
|
||||
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
return ErrSubscriptionExpired
|
||||
}
|
||||
return nil
|
||||
|
||||
185
backend/internal/service/token_refresh_service.go
Normal file
185
backend/internal/service/token_refresh_service.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// TokenRefreshService OAuth token自动刷新服务
|
||||
// 定期检查并刷新即将过期的token
|
||||
type TokenRefreshService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
refreshers []TokenRefresher
|
||||
cfg *config.TokenRefreshConfig
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewTokenRefreshService 创建token刷新服务
|
||||
func NewTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
accountRepo: accountRepo,
|
||||
cfg: &cfg.TokenRefresh,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 注册平台特定的刷新器
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
// 未来可以添加其他平台的刷新器:
|
||||
// NewOpenAITokenRefresher(...),
|
||||
// NewGeminiTokenRefresher(...),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
log.Println("[TokenRefresh] Service disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.refreshLoop()
|
||||
|
||||
log.Printf("[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)",
|
||||
s.cfg.CheckIntervalMinutes, s.cfg.RefreshBeforeExpiryHours)
|
||||
}
|
||||
|
||||
// Stop 停止刷新服务
|
||||
func (s *TokenRefreshService) Stop() {
|
||||
close(s.stopCh)
|
||||
s.wg.Wait()
|
||||
log.Println("[TokenRefresh] Service stopped")
|
||||
}
|
||||
|
||||
// refreshLoop 刷新循环
|
||||
func (s *TokenRefreshService) refreshLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 计算检查间隔
|
||||
checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
|
||||
if checkInterval < time.Minute {
|
||||
checkInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动时立即执行一次检查
|
||||
s.processRefresh()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.processRefresh()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processRefresh 执行一次刷新检查
|
||||
func (s *TokenRefreshService) processRefresh() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 计算刷新窗口
|
||||
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
|
||||
|
||||
// 获取所有active状态的账号
|
||||
accounts, err := s.listActiveAccounts(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to list accounts: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
refreshed, failed := 0, 0
|
||||
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
|
||||
// 遍历所有刷新器,找到能处理此账号的
|
||||
for _, refresher := range s.refreshers {
|
||||
if !refresher.CanRefresh(account) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否需要刷新
|
||||
if !refresher.NeedsRefresh(account, refreshWindow) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行刷新
|
||||
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
|
||||
log.Printf("[TokenRefresh] Account %d (%s) failed: %v", account.ID, account.Name, err)
|
||||
failed++
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", account.ID, account.Name)
|
||||
refreshed++
|
||||
}
|
||||
|
||||
// 每个账号只由一个refresher处理
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if refreshed > 0 || failed > 0 {
|
||||
log.Printf("[TokenRefresh] Cycle complete: %d refreshed, %d failed", refreshed, failed)
|
||||
}
|
||||
}
|
||||
|
||||
// listActiveAccounts 获取所有active状态的账号
|
||||
// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的)
|
||||
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
|
||||
return s.accountRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
// refreshWithRetry 带重试的刷新
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
if err == nil {
|
||||
// 刷新成功,更新账号credentials
|
||||
account.Credentials = model.JSONB(newCredentials)
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
|
||||
account.ID, attempt, s.cfg.MaxRetries, err)
|
||||
|
||||
// 如果还有重试机会,等待后重试
|
||||
if attempt < s.cfg.MaxRetries {
|
||||
// 指数退避:2^(attempt-1) * baseSeconds
|
||||
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
// 所有重试都失败,标记账号为error状态
|
||||
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
90
backend/internal/service/token_refresher.go
Normal file
90
backend/internal/service/token_refresher.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
)
|
||||
|
||||
// TokenRefresher 定义平台特定的token刷新策略接口
|
||||
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
|
||||
type TokenRefresher interface {
|
||||
// CanRefresh 检查此刷新器是否能处理指定账号
|
||||
CanRefresh(account *model.Account) bool
|
||||
|
||||
// NeedsRefresh 检查账号的token是否需要刷新
|
||||
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
|
||||
|
||||
// Refresh 执行token刷新,返回更新后的credentials
|
||||
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
||||
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
||||
type ClaudeTokenRefresher struct {
|
||||
oauthService *OAuthService
|
||||
}
|
||||
|
||||
// NewClaudeTokenRefresher 创建Claude token刷新器
|
||||
func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
|
||||
return &ClaudeTokenRefresher{
|
||||
oauthService: oauthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 anthropic 平台的 oauth 类型账号
|
||||
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
|
||||
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
|
||||
return account.Platform == model.PlatformAnthropic &&
|
||||
account.Type == model.AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
expiryTime := time.Unix(expiresAt, 0)
|
||||
return time.Until(expiryTime) < refreshWindow
|
||||
}
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
|
||||
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留现有credentials中的所有字段
|
||||
newCredentials := make(map[string]interface{})
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
// 只更新token相关字段
|
||||
// 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
@@ -2,14 +2,9 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -19,10 +14,15 @@ var (
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||
type TurnstileVerifier interface {
|
||||
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
||||
}
|
||||
|
||||
// TurnstileService Turnstile 验证服务
|
||||
type TurnstileService struct {
|
||||
settingService *SettingService
|
||||
httpClient *http.Client
|
||||
verifier TurnstileVerifier
|
||||
}
|
||||
|
||||
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
|
||||
@@ -36,12 +36,10 @@ type TurnstileVerifyResponse struct {
|
||||
}
|
||||
|
||||
// NewTurnstileService 创建 Turnstile 服务实例
|
||||
func NewTurnstileService(settingService *SettingService) *TurnstileService {
|
||||
func NewTurnstileService(settingService *SettingService, verifier TurnstileVerifier) *TurnstileService {
|
||||
return &TurnstileService{
|
||||
settingService: settingService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,35 +64,12 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote
|
||||
return ErrTurnstileVerificationFailed
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
formData := url.Values{}
|
||||
formData.Set("secret", secretKey)
|
||||
formData.Set("response", token)
|
||||
if remoteIP != "" {
|
||||
formData.Set("remoteip", remoteIP)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// 发送请求
|
||||
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
|
||||
resp, err := s.httpClient.Do(req)
|
||||
result, err := s.verifier.VerifyToken(ctx, secretKey, token, remoteIP)
|
||||
if err != nil {
|
||||
log.Printf("[Turnstile] Request failed: %v", err)
|
||||
return fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 解析响应
|
||||
var result TurnstileVerifyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
log.Printf("[Turnstile] Failed to decode response: %v", err)
|
||||
return fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -18,7 +17,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,17 +33,26 @@ const (
|
||||
maxDownloadSize = 500 * 1024 * 1024
|
||||
)
|
||||
|
||||
// GitHubReleaseClient 获取 GitHub release 信息的接口
|
||||
type GitHubReleaseClient interface {
|
||||
FetchLatestRelease(ctx context.Context, repo string) (*GitHubRelease, error)
|
||||
DownloadFile(ctx context.Context, url, dest string, maxSize int64) error
|
||||
FetchChecksumFile(ctx context.Context, url string) ([]byte, error)
|
||||
}
|
||||
|
||||
// UpdateService handles software updates
|
||||
type UpdateService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.UpdateCache
|
||||
githubClient GitHubReleaseClient
|
||||
currentVersion string
|
||||
buildType string // "source" for manual builds, "release" for CI builds
|
||||
}
|
||||
|
||||
// NewUpdateService creates a new UpdateService
|
||||
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
|
||||
func NewUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, version, buildType string) *UpdateService {
|
||||
return &UpdateService{
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
githubClient: githubClient,
|
||||
currentVersion: version,
|
||||
buildType: buildType,
|
||||
}
|
||||
@@ -260,42 +268,11 @@ func (s *UpdateService) Rollback() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
release, err := s.githubClient.FetchLatestRelease(ctx, githubRepo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return &UpdateInfo{
|
||||
CurrentVersion: s.currentVersion,
|
||||
LatestVersion: s.currentVersion,
|
||||
HasUpdate: false,
|
||||
Warning: "No releases found",
|
||||
BuildType: s.buildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release GitHubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
latestVersion := strings.TrimPrefix(release.TagName, "v")
|
||||
|
||||
@@ -325,47 +302,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
|
||||
}
|
||||
|
||||
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// SECURITY: Check Content-Length if available
|
||||
if resp.ContentLength > maxDownloadSize {
|
||||
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
|
||||
}
|
||||
|
||||
out, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxDownloadSize)
|
||||
if written > maxDownloadSize {
|
||||
os.Remove(dest) // Clean up partial file
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.githubClient.DownloadFile(ctx, downloadURL, dest, maxDownloadSize)
|
||||
}
|
||||
|
||||
func (s *UpdateService) getArchiveName() string {
|
||||
@@ -402,20 +339,9 @@ func validateDownloadURL(rawURL string) error {
|
||||
|
||||
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
|
||||
// Download checksums file
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
|
||||
checksumData, err := s.githubClient.FetchChecksumFile(ctx, checksumURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
|
||||
return fmt.Errorf("failed to download checksums: %w", err)
|
||||
}
|
||||
|
||||
// Calculate file hash
|
||||
@@ -433,7 +359,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
|
||||
|
||||
// Find expected hash in checksums file
|
||||
fileName := filepath.Base(filePath)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := bufio.NewScanner(strings.NewReader(string(checksumData)))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
parts := strings.Fields(line)
|
||||
@@ -533,7 +459,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
}
|
||||
|
||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
|
||||
data, err := s.cache.GetUpdateInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -573,7 +499,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(cacheData)
|
||||
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
|
||||
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
}
|
||||
|
||||
// compareVersions compares two semantic versions
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -41,24 +42,24 @@ type CreateUsageLogRequest struct {
|
||||
|
||||
// UsageStats 使用统计
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
|
||||
// UsageService 使用统计服务
|
||||
type UsageService struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
userRepo *repository.UserRepository
|
||||
usageRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
}
|
||||
|
||||
// NewUsageService 创建使用统计服务实例
|
||||
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
|
||||
func NewUsageService(usageRepo ports.UsageLogRepository, userRepo ports.UserRepository) *UsageService {
|
||||
return &UsageService{
|
||||
usageRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
@@ -127,7 +128,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
|
||||
}
|
||||
|
||||
// ListByUser 获取用户的使用日志列表
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -136,7 +137,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repo
|
||||
}
|
||||
|
||||
// ListByApiKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -145,7 +146,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
|
||||
}
|
||||
|
||||
// ListByAccount 获取账号的使用日志列表
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -233,15 +234,15 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,16 +6,17 @@ import (
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrPasswordIncorrect = errors.New("current password is incorrect")
|
||||
ErrInsufficientPerms = errors.New("insufficient permissions")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrPasswordIncorrect = errors.New("current password is incorrect")
|
||||
ErrInsufficientPerms = errors.New("insufficient permissions")
|
||||
)
|
||||
|
||||
// UpdateProfileRequest 更新用户资料请求
|
||||
@@ -32,12 +33,12 @@ type ChangePasswordRequest struct {
|
||||
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo *repository.UserRepository
|
||||
userRepo ports.UserRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
|
||||
func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
@@ -133,7 +134,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
|
||||
}
|
||||
|
||||
// List 获取用户列表(管理员功能)
|
||||
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
|
||||
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
users, pagination, err := s.userRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list users: %w", err)
|
||||
|
||||
79
backend/internal/service/wire.go
Normal file
79
backend/internal/service/wire.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// BuildInfo contains build information
|
||||
type BuildInfo struct {
|
||||
Version string
|
||||
BuildType string
|
||||
}
|
||||
|
||||
// ProvidePricingService creates and initializes PricingService
|
||||
func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) {
|
||||
svc := NewPricingService(cfg, remoteClient)
|
||||
if err := svc.Initialize(); err != nil {
|
||||
// 价格服务初始化失败不应阻止启动,使用回退价格
|
||||
println("[Service] Warning: Pricing service initialization failed:", err.Error())
|
||||
}
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// ProvideUpdateService creates UpdateService with BuildInfo
|
||||
func ProvideUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, buildInfo BuildInfo) *UpdateService {
|
||||
return NewUpdateService(cache, githubClient, buildInfo.Version, buildInfo.BuildType)
|
||||
}
|
||||
|
||||
// ProvideEmailQueueService creates EmailQueueService with default worker count
|
||||
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
return NewEmailQueueService(emailService, 3)
|
||||
}
|
||||
|
||||
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
NewAuthService,
|
||||
NewUserService,
|
||||
NewApiKeyService,
|
||||
NewGroupService,
|
||||
NewAccountService,
|
||||
NewProxyService,
|
||||
NewRedeemService,
|
||||
NewUsageService,
|
||||
ProvidePricingService,
|
||||
NewBillingService,
|
||||
NewBillingCacheService,
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
NewOAuthService,
|
||||
NewRateLimitService,
|
||||
NewAccountUsageService,
|
||||
NewAccountTestService,
|
||||
NewSettingService,
|
||||
NewEmailService,
|
||||
ProvideEmailQueueService,
|
||||
NewTurnstileService,
|
||||
NewSubscriptionService,
|
||||
NewConcurrencyService,
|
||||
NewIdentityService,
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
|
||||
// Provide the Services container struct
|
||||
wire.Struct(new(Services), "*"),
|
||||
)
|
||||
@@ -96,7 +96,7 @@ services:
|
||||
# PostgreSQL Database
|
||||
# ===========================================================================
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
image: postgres:18-alpine
|
||||
container_name: sub2api-postgres
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<link rel="icon" type="image/png" href="/logo.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Sub2API - AI API Gateway</title>
|
||||
</head>
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
<script setup lang="ts">
|
||||
import { RouterView, useRouter, useRoute } from 'vue-router'
|
||||
import { onMounted } from 'vue'
|
||||
import { onMounted, watch } from 'vue'
|
||||
import Toast from '@/components/common/Toast.vue'
|
||||
import { getPublicSettings } from '@/api/auth'
|
||||
import { useAppStore } from '@/stores'
|
||||
import { getSetupStatus } from '@/api/setup'
|
||||
|
||||
const router = useRouter()
|
||||
const route = useRoute()
|
||||
const appStore = useAppStore()
|
||||
|
||||
/**
|
||||
* Update favicon dynamically
|
||||
@@ -24,6 +25,19 @@ function updateFavicon(logoUrl: string) {
|
||||
link.href = logoUrl
|
||||
}
|
||||
|
||||
// Watch for site settings changes and update favicon/title
|
||||
watch(() => appStore.siteLogo, (newLogo) => {
|
||||
if (newLogo) {
|
||||
updateFavicon(newLogo)
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
watch(() => appStore.siteName, (newName) => {
|
||||
if (newName) {
|
||||
document.title = `${newName} - AI API Gateway`
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
onMounted(async () => {
|
||||
// Check if setup is needed
|
||||
try {
|
||||
@@ -36,21 +50,8 @@ onMounted(async () => {
|
||||
// If setup endpoint fails, assume normal mode and continue
|
||||
}
|
||||
|
||||
try {
|
||||
const settings = await getPublicSettings()
|
||||
|
||||
// Update favicon if logo is set
|
||||
if (settings.site_logo) {
|
||||
updateFavicon(settings.site_logo)
|
||||
}
|
||||
|
||||
// Update page title if site name is set
|
||||
if (settings.site_name) {
|
||||
document.title = `${settings.site_name} - AI API Gateway`
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load public settings for favicon:', error)
|
||||
}
|
||||
// Load public settings into appStore (will be cached for other components)
|
||||
await appStore.fetchPublicSettings()
|
||||
})
|
||||
</script>
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user