mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1325e9ae6 | ||
|
|
587012396b | ||
|
|
adebd941e1 | ||
|
|
bb500b7b2a | ||
|
|
cceada7dae | ||
|
|
5c2e7ae265 | ||
|
|
420bedd615 | ||
|
|
a79f6c5e1e | ||
|
|
0484c59ead | ||
|
|
7bbf621490 | ||
|
|
ef81aeb463 | ||
|
|
22414326cc | ||
|
|
14b155c66b | ||
|
|
e99b344b2b |
@@ -40,6 +40,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// 服务器层 ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
// 清理函数提供者
|
||||
provideCleanup,
|
||||
|
||||
@@ -49,6 +52,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
BuildType: buildInfo.BuildType,
|
||||
}
|
||||
}
|
||||
|
||||
func provideCleanup(
|
||||
db *gorm.DB,
|
||||
rdb *redis.Client,
|
||||
@@ -63,6 +73,10 @@ func provideCleanup(
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -41,8 +41,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
settingRepository := repository.NewSettingRepository(db)
|
||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||
client := infrastructure.ProvideRedis(configConfig)
|
||||
emailService := service.NewEmailService(settingRepository, client)
|
||||
turnstileService := service.NewTurnstileService(settingService)
|
||||
emailCache := repository.NewEmailCache(client)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(authService)
|
||||
@@ -51,55 +53,60 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||
groupRepository := repository.NewGroupRepository(db)
|
||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(db)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig)
|
||||
apiKeyCache := repository.NewApiKeyCache(client)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
||||
accountRepository := repository.NewAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(db)
|
||||
repositories := &repository.Repositories{
|
||||
User: userRepository,
|
||||
ApiKey: apiKeyRepository,
|
||||
Group: groupRepository,
|
||||
Account: accountRepository,
|
||||
Proxy: proxyRepository,
|
||||
RedeemCode: redeemCodeRepository,
|
||||
UsageLog: usageLogRepository,
|
||||
Setting: settingRepository,
|
||||
UserSubscription: userSubscriptionRepository,
|
||||
}
|
||||
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
|
||||
subscriptionService := service.NewSubscriptionService(repositories, billingCacheService)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
|
||||
billingCache := repository.NewBillingCache(client)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||
redeemCache := repository.NewRedeemCache(client)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
adminService := service.NewAdminService(repositories, billingCacheService)
|
||||
accountRepository := repository.NewAccountRepository(db)
|
||||
proxyRepository := repository.NewProxyRepository(db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
|
||||
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
groupHandler := admin.NewGroupHandler(adminService)
|
||||
oAuthService := service.NewOAuthService(proxyRepository)
|
||||
rateLimitService := service.NewRateLimitService(repositories, configConfig)
|
||||
accountUsageService := service.NewAccountUsageService(repositories, oAuthService)
|
||||
accountTestService := service.NewAccountTestService(repositories, oAuthService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
|
||||
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||
systemHandler := handler.ProvideSystemHandler(client, buildInfo)
|
||||
updateCache := repository.NewUpdateCache(client)
|
||||
gitHubReleaseClient := repository.NewGitHubReleaseClient()
|
||||
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
|
||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
pricingService, err := service.ProvidePricingService(configConfig)
|
||||
gatewayCache := repository.NewGatewayCache(client)
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(client)
|
||||
gatewayService := service.NewGatewayService(repositories, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
|
||||
concurrencyService := service.NewConcurrencyService(client)
|
||||
identityCache := repository.NewIdentityCache(client)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||
@@ -131,6 +138,19 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
}
|
||||
repositories := &repository.Repositories{
|
||||
User: userRepository,
|
||||
ApiKey: apiKeyRepository,
|
||||
Group: groupRepository,
|
||||
Account: accountRepository,
|
||||
Proxy: proxyRepository,
|
||||
RedeemCode: redeemCodeRepository,
|
||||
UsageLog: usageLogRepository,
|
||||
Setting: settingRepository,
|
||||
UserSubscription: userSubscriptionRepository,
|
||||
}
|
||||
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
@@ -149,6 +169,13 @@ type Application struct {
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
BuildType: buildInfo.BuildType,
|
||||
}
|
||||
}
|
||||
|
||||
func provideCleanup(
|
||||
db *gorm.DB,
|
||||
rdb *redis.Client,
|
||||
@@ -162,6 +189,10 @@ func provideCleanup(
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -8,15 +8,30 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
}
|
||||
|
||||
// TokenRefreshConfig OAuth token自动刷新配置
|
||||
type TokenRefreshConfig struct {
|
||||
// 是否启用自动刷新
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// 检查间隔(分钟)
|
||||
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
|
||||
// 提前刷新时间(小时),在token过期前多久开始刷新
|
||||
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
|
||||
// 最大重试次数
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
// 重试退避基础时间(秒)
|
||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
@@ -192,6 +207,13 @@ func setDefaults() {
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
|
||||
@@ -127,12 +127,25 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour)
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity)
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.usageRepo.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -148,11 +161,24 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD)
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime)
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.usageRepo.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
|
||||
@@ -256,3 +256,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
|
||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取管理员 API Key 状态
|
||||
// GET /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"exists": exists,
|
||||
"masked_key": maskedKey,
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
|
||||
// POST /api/v1/admin/settings/admin-api-key/regenerate
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"key": key, // 完整 key 只在生成时返回一次
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||
}
|
||||
|
||||
@@ -4,15 +4,15 @@ import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
|
||||
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
|
||||
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
|
||||
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SystemHandler handles system-related operations
|
||||
@@ -18,9 +17,9 @@ type SystemHandler struct {
|
||||
}
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
|
||||
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||
return &SystemHandler{
|
||||
updateSvc: service.NewUpdateService(rdb, version, buildType),
|
||||
updateSvc: updateSvc,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
@@ -14,10 +15,10 @@ import (
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
usageService *service.UsageService
|
||||
adminService service.AdminService
|
||||
usageRepo *repository.UsageLogRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
usageService *service.UsageService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
@@ -82,7 +83,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := repository.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"strconv"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/response"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
@@ -68,9 +69,9 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
var records []model.UsageLog
|
||||
var result *repository.PaginationResult
|
||||
var result *pagination.PaginationResult
|
||||
var err error
|
||||
|
||||
if apiKeyID > 0 {
|
||||
@@ -362,7 +363,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Verify ownership of all requested API keys
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to verify API key ownership")
|
||||
return
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ProvideAdminHandlers creates the AdminHandlers struct
|
||||
@@ -37,9 +36,9 @@ func ProvideAdminHandlers(
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideSystemHandler creates admin.SystemHandler with BuildInfo parameters
|
||||
func ProvideSystemHandler(rdb *redis.Client, buildInfo BuildInfo) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType)
|
||||
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService)
|
||||
}
|
||||
|
||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||
|
||||
130
backend/internal/middleware/admin_auth.go
Normal file
130
backend/internal/middleware/admin_auth.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminAuth 管理员认证中间件
|
||||
// 支持两种认证方式(通过不同的 header 区分):
|
||||
// 1. Admin API Key: x-api-key: <admin-api-key>
|
||||
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
|
||||
func AdminAuth(
|
||||
authService *service.AuthService,
|
||||
userRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
},
|
||||
settingService *service.SettingService,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userRepo) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Authorization header(JWT 认证)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
if !validateJWTForAdmin(c, parts[1], authService, userRepo) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 无有效认证信息
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
}
|
||||
}
|
||||
|
||||
// validateAdminApiKey 验证管理员 API Key
|
||||
func validateAdminApiKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userRepo interface {
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
},
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||
return false
|
||||
}
|
||||
|
||||
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
|
||||
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
|
||||
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取真实的管理员用户
|
||||
admin, err := userRepo.GetFirstAdmin(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), admin)
|
||||
c.Set("auth_method", "admin_api_key")
|
||||
return true
|
||||
}
|
||||
|
||||
// validateJWTForAdmin 验证 JWT 并检查管理员权限
|
||||
func validateJWTForAdmin(
|
||||
c *gin.Context,
|
||||
token string,
|
||||
authService *service.AuthService,
|
||||
userRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
},
|
||||
) bool {
|
||||
// 验证 JWT token
|
||||
claims, err := authService.ValidateToken(token)
|
||||
if err != nil {
|
||||
if err == service.ErrTokenExpired {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return false
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return false
|
||||
}
|
||||
|
||||
// 从数据库获取用户
|
||||
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查管理员权限
|
||||
if user.Role != model.RoleAdmin {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), user)
|
||||
c.Set("auth_method", "jwt")
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -46,8 +46,14 @@ const (
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
)
|
||||
|
||||
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
|
||||
const AdminApiKeyPrefix = "admin-"
|
||||
|
||||
// SystemSettings 系统设置结构体(用于API响应)
|
||||
type SystemSettings struct {
|
||||
// 注册设置
|
||||
|
||||
42
backend/internal/pkg/pagination/pagination.go
Normal file
42
backend/internal/pkg/pagination/pagination.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package pagination
|
||||
|
||||
// PaginationParams 分页参数
|
||||
type PaginationParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// DefaultPagination 默认分页参数
|
||||
func DefaultPagination() PaginationParams {
|
||||
return PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset 计算偏移量
|
||||
func (p PaginationParams) Offset() int {
|
||||
if p.Page < 1 {
|
||||
p.Page = 1
|
||||
}
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
// Limit 获取限制数
|
||||
func (p PaginationParams) Limit() int {
|
||||
if p.PageSize < 1 {
|
||||
return 20
|
||||
}
|
||||
if p.PageSize > 100 {
|
||||
return 100
|
||||
}
|
||||
return p.PageSize
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in
|
||||
})
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果(与repository.PaginationResult兼容)
|
||||
// PaginationResult 分页结果(与pagination.PaginationResult兼容)
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
|
||||
8
backend/internal/pkg/usagestats/account_stats.go
Normal file
8
backend/internal/pkg/usagestats/account_stats.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package usagestats
|
||||
|
||||
// AccountStats 账号使用统计
|
||||
type AccountStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -47,12 +48,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
|
||||
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
|
||||
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
|
||||
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
var accounts []model.Account
|
||||
var total int64
|
||||
|
||||
@@ -94,7 +95,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
|
||||
pages++
|
||||
}
|
||||
|
||||
return accounts, &PaginationResult{
|
||||
return accounts, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -226,7 +227,7 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"rate_limited_at": now,
|
||||
"rate_limited_at": now,
|
||||
"rate_limit_reset_at": resetAt,
|
||||
}).Error
|
||||
}
|
||||
|
||||
51
backend/internal/repository/api_key_cache.go
Normal file
51
backend/internal/repository/api_key_cache.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return c.rdb.Incr(ctx, apiKey).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
|
||||
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
@@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
pages++
|
||||
}
|
||||
|
||||
return keys, &PaginationResult{
|
||||
return keys, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
|
||||
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
@@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
pages++
|
||||
}
|
||||
|
||||
return keys, &PaginationResult{
|
||||
return keys, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
174
backend/internal/repository/billing_cache.go
Normal file
174
backend/internal/repository/billing_cache.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
var (
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type billingCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
|
||||
return &billingCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
|
||||
result := &ports.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
235
backend/internal/repository/claude_oauth_service.go
Normal file
235
backend/internal/repository/claude_oauth_service.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type claudeOAuthService struct{}
|
||||
|
||||
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
||||
return &claudeOAuthService{}
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
targetURL := "https://claude.ai/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetSuccessResult(&orgs).
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if len(orgs) == 0 {
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||
SetHeader("Cache-Control", "no-cache").
|
||||
SetHeader("Origin", "https://claude.ai").
|
||||
SetHeader("Referer", "https://claude.ai/new").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&result).
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if result.RedirectURI == "" {
|
||||
return "", fmt.Errorf("no redirect_uri in response")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(result.RedirectURI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||
}
|
||||
|
||||
queryParams := parsedURL.Query()
|
||||
authCode := queryParams.Get("code")
|
||||
responseState := queryParams.Get("state")
|
||||
|
||||
if authCode == "" {
|
||||
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||
}
|
||||
|
||||
fullCode := authCode
|
||||
if responseState != "" {
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if len(code) > 0 {
|
||||
parts := make([]string, 0, 2)
|
||||
for i, part := range []rune(code) {
|
||||
if part == '#' {
|
||||
authCode = code[:i]
|
||||
codeState = code[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
authCode = code
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"code_verifier": codeVerifier,
|
||||
}
|
||||
|
||||
if codeState != "" {
|
||||
reqBody["state"] = codeState
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := createReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauth.ClientID)
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
ImpersonateChrome().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
64
backend/internal/repository/claude_service.go
Normal file
64
backend/internal/repository/claude_service.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
type claudeUpstreamService struct {
|
||||
defaultClient *http.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &claudeUpstreamService{
|
||||
defaultClient: &http.Client{Transport: transport},
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
if proxyURL == "" {
|
||||
return s.defaultClient.Do(req)
|
||||
}
|
||||
client := s.createProxyClient(proxyURL)
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return s.defaultClient
|
||||
}
|
||||
|
||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
59
backend/internal/repository/claude_usage_service.go
Normal file
59
backend/internal/repository/claude_usage_service.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
type claudeUsageService struct{}
|
||||
|
||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{}
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
return &usageResp, nil
|
||||
}
|
||||
132
backend/internal/repository/concurrency_cache.go
Normal file
132
backend/internal/repository/concurrency_cache.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
accountConcurrencyKeyPrefix = "concurrency:account:"
|
||||
userConcurrencyKeyPrefix = "concurrency:user:"
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
concurrencyTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
|
||||
return &concurrencyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKeyPrefix, accountID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
_, err := releaseScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(concurrencyTTL.Seconds())).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
48
backend/internal/repository/email_cache.go
Normal file
48
backend/internal/repository/email_cache.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
|
||||
return &emailCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data ports.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
35
backend/internal/repository/gateway_cache.go
Normal file
35
backend/internal/repository/gateway_cache.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Get(ctx, key).Int64()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
116
backend/internal/repository/github_release_service.go
Normal file
116
backend/internal/repository/github_release_service.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release service.GitHubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &release, nil
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// SECURITY: Check Content-Length if available
|
||||
if resp.ContentLength > maxSize {
|
||||
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
|
||||
}
|
||||
|
||||
out, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxSize)
|
||||
if written > maxSize {
|
||||
os.Remove(dest) // Clean up partial file
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) {
|
||||
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
}
|
||||
|
||||
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
||||
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) {
|
||||
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
var groups []model.Group
|
||||
var total int64
|
||||
|
||||
@@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination
|
||||
pages++
|
||||
}
|
||||
|
||||
return groups, &PaginationResult{
|
||||
return groups, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
47
backend/internal/repository/identity_cache.go
Normal file
47
backend/internal/repository/identity_cache.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
fingerprintKeyPrefix = "fingerprint:"
|
||||
fingerprintTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
|
||||
return &identityCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fp ports.Fingerprint
|
||||
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fp, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := json.Marshal(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||
}
|
||||
73
backend/internal/repository/pricing_service.go
Normal file
73
backend/internal/repository/pricing_service.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
type pricingRemoteClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewPricingRemoteClient() service.PricingRemoteClient {
|
||||
return &pricingRemoteClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 哈希文件格式:hash filename 或者纯 hash
|
||||
hash := strings.TrimSpace(string(body))
|
||||
parts := strings.Fields(hash)
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
104
backend/internal/repository/proxy_probe_service.go
Normal file
104
backend/internal/repository/proxy_probe_service.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type proxyProbeService struct{}
|
||||
|
||||
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||
return &proxyProbeService{}
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
transport, err := createProxyTransport(proxyURL)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var ipInfo struct {
|
||||
IP string `json:"ip"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &service.ProxyExitInfo{
|
||||
IP: ipInfo.IP,
|
||||
City: ipInfo.City,
|
||||
Region: ipInfo.Region,
|
||||
Country: ipInfo.Country,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
case "socks5":
|
||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) {
|
||||
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) {
|
||||
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
var proxies []model.Proxy
|
||||
var total int64
|
||||
|
||||
@@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination
|
||||
pages++
|
||||
}
|
||||
|
||||
return proxies, &PaginationResult{
|
||||
return proxies, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
49
backend/internal/repository/redeem_cache.go
Normal file
49
backend/internal/repository/redeem_cache.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
|
||||
redeemLockKeyPrefix = "redeem:lock:"
|
||||
redeemRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type redeemCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
|
||||
return &redeemCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
}
|
||||
|
||||
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
|
||||
key := redeemLockKeyPrefix + code
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
|
||||
key := redeemLockKeyPrefix + code
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) {
|
||||
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
|
||||
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) {
|
||||
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
var codes []model.RedeemCode
|
||||
var total int64
|
||||
|
||||
@@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin
|
||||
pages++
|
||||
}
|
||||
|
||||
return codes, &PaginationResult{
|
||||
return codes, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
@@ -12,44 +12,3 @@ type Repositories struct {
|
||||
Setting *SettingRepository
|
||||
UserSubscription *UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// PaginationParams 分页参数
|
||||
type PaginationParams struct {
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// DefaultPagination 默认分页参数
|
||||
func DefaultPagination() PaginationParams {
|
||||
return PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset 计算偏移量
|
||||
func (p PaginationParams) Offset() int {
|
||||
if p.Page < 1 {
|
||||
p.Page = 1
|
||||
}
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
// Limit 获取限制数
|
||||
func (p PaginationParams) Limit() int {
|
||||
if p.PageSize < 1 {
|
||||
return 20
|
||||
}
|
||||
if p.PageSize > 100 {
|
||||
return 100
|
||||
}
|
||||
return p.PageSize
|
||||
}
|
||||
|
||||
55
backend/internal/repository/turnstile_service.go
Normal file
55
backend/internal/repository/turnstile_service.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service"
|
||||
)
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
type turnstileVerifier struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||
return &turnstileVerifier{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
|
||||
formData := url.Values{}
|
||||
formData.Set("secret", secretKey)
|
||||
formData.Set("response", token)
|
||||
if remoteIP != "" {
|
||||
formData.Set("remoteip", remoteIP)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := v.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result service.TurnstileVerifyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
28
backend/internal/repository/update_cache.go
Normal file
28
backend/internal/repository/update_cache.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const updateCacheKey = "update:latest"
|
||||
|
||||
type updateCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
|
||||
return &updateCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
|
||||
return c.rdb.Get(ctx, updateCacheKey).Result()
|
||||
}
|
||||
|
||||
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
|
||||
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
|
||||
}
|
||||
@@ -3,7 +3,9 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/pkg/usagestats"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -30,7 +32,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag
|
||||
return &log, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -49,7 +51,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &PaginationResult{
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -57,7 +59,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -76,7 +78,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &PaginationResult{
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -270,7 +272,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -289,7 +291,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &PaginationResult{
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -297,7 +299,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
|
||||
@@ -306,7 +308,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
|
||||
@@ -315,7 +317,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
@@ -324,7 +326,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
|
||||
@@ -337,15 +339,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
|
||||
}
|
||||
|
||||
// AccountStats 账号使用统计
|
||||
type AccountStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
|
||||
// GetAccountTodayStats 获取账号今日统计
|
||||
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*AccountStats, error) {
|
||||
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||
today := timezone.Today()
|
||||
|
||||
var stats struct {
|
||||
@@ -367,7 +362,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountStats{
|
||||
return &usagestats.AccountStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
@@ -375,7 +370,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// GetAccountWindowStats 获取账号时间窗口内的统计
|
||||
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*AccountStats, error) {
|
||||
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
var stats struct {
|
||||
Requests int64 `gorm:"column:requests"`
|
||||
Tokens int64 `gorm:"column:tokens"`
|
||||
@@ -395,7 +390,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountStats{
|
||||
return &usagestats.AccountStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
@@ -436,75 +431,13 @@ type UserUsageTrendPoint struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GetUsageTrend returns usage trend data grouped by date
|
||||
// granularity: "day" or "hour"
|
||||
func (r *UsageLogRepository) GetUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
|
||||
var results []TrendDataPoint
|
||||
|
||||
// Choose date format based on granularity
|
||||
var dateFormat string
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
} else {
|
||||
dateFormat = "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, ?) as date,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`, dateFormat).
|
||||
Where("created_at >= ? AND created_at < ?", startTime, endTime).
|
||||
Group("date").
|
||||
Order("date ASC").
|
||||
Scan(&results).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetModelStats returns usage statistics grouped by model
|
||||
func (r *UsageLogRepository) GetModelStats(ctx context.Context, startTime, endTime time.Time) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`).
|
||||
Where("created_at >= ? AND created_at < ?", startTime, endTime).
|
||||
Group("model").
|
||||
Order("total_tokens DESC").
|
||||
Scan(&results).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Date string `json:"date"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
@@ -780,7 +713,7 @@ type UsageLogFilters struct {
|
||||
}
|
||||
|
||||
// ListWithFilters lists usage logs with optional filters (for admin)
|
||||
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *PaginationResult, error) {
|
||||
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -816,7 +749,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &PaginationResult{
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -838,7 +771,7 @@ type UsageStats struct {
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
@@ -964,6 +897,76 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
||||
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
|
||||
var results []TrendDataPoint
|
||||
|
||||
var dateFormat string
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
} else {
|
||||
dateFormat = "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, ?) as date,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`, dateFormat).
|
||||
Where("created_at >= ? AND created_at < ?", startTime, endTime)
|
||||
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
db = db.Where("api_key_id = ?", apiKeyID)
|
||||
}
|
||||
|
||||
err := db.Group("date").Order("date ASC").Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
`).
|
||||
Where("created_at >= ? AND created_at < ?", startTime, endTime)
|
||||
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
db = db.Where("api_key_id = ?", apiKeyID)
|
||||
}
|
||||
|
||||
err := db.Group("model").Order("total_tokens DESC").Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetGlobalStats gets usage statistics for all users within a time range
|
||||
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
var stats struct {
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) {
|
||||
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists users with optional filtering by status, role, and search query
|
||||
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) {
|
||||
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
@@ -81,7 +82,7 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP
|
||||
pages++
|
||||
}
|
||||
|
||||
return users, &PaginationResult{
|
||||
return users, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -128,3 +129,15 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
|
||||
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
|
||||
Order("id ASC").
|
||||
First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
|
||||
}
|
||||
|
||||
// ListByGroupID 获取分组的所有订阅(分页)
|
||||
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) {
|
||||
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
var total int64
|
||||
|
||||
@@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
pages++
|
||||
}
|
||||
|
||||
return subs, &PaginationResult{
|
||||
return subs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
@@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) {
|
||||
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
var total int64
|
||||
|
||||
@@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
|
||||
pages++
|
||||
}
|
||||
|
||||
return subs, &PaginationResult{
|
||||
return subs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
@@ -16,4 +18,34 @@ var ProviderSet = wire.NewSet(
|
||||
NewSettingRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
wire.Struct(new(Repositories), "*"),
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewApiKeyCache,
|
||||
NewConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
NewRedeemCache,
|
||||
NewUpdateCache,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
NewPricingRemoteClient,
|
||||
NewGitHubReleaseClient,
|
||||
NewProxyExitInfoProber,
|
||||
NewClaudeUsageFetcher,
|
||||
NewClaudeOAuthClient,
|
||||
NewClaudeUpstream,
|
||||
|
||||
// Bind concrete repositories to service port interfaces
|
||||
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
||||
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
|
||||
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
|
||||
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
|
||||
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
|
||||
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
|
||||
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
|
||||
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
|
||||
)
|
||||
|
||||
@@ -132,7 +132,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
|
||||
// 管理员接口
|
||||
admin := v1.Group("/admin")
|
||||
admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly())
|
||||
admin.Use(middleware.AdminAuth(s.Auth, repos.User, s.Setting))
|
||||
{
|
||||
// 仪表盘
|
||||
dashboard := admin.Group("/dashboard")
|
||||
@@ -236,6 +236,10 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
// Admin API Key 管理
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
|
||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
|
||||
}
|
||||
|
||||
// 系统管理
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -41,12 +42,12 @@ type UpdateAccountRequest struct {
|
||||
|
||||
// AccountService 账号管理服务
|
||||
type AccountService struct {
|
||||
accountRepo *repository.AccountRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
accountRepo ports.AccountRepository
|
||||
groupRepo ports.GroupRepository
|
||||
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
|
||||
func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
|
||||
}
|
||||
|
||||
// List 获取账号列表
|
||||
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
|
||||
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
||||
|
||||
@@ -10,13 +10,12 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -37,19 +36,17 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
repos *repository.Repositories
|
||||
oauthService *OAuthService
|
||||
httpClient *http.Client
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
claudeUpstream ClaudeUpstream
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
repos: repos,
|
||||
oauthService: oauthService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
claudeUpstream: claudeUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,7 +102,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get account
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Account not found")
|
||||
}
|
||||
@@ -209,23 +206,13 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
req.Header.Set("x-api-key", authToken)
|
||||
}
|
||||
|
||||
// Configure proxy if account has one
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
resp, err := s.claudeUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
@@ -2,17 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
@@ -35,10 +31,10 @@ type WindowStats struct {
|
||||
|
||||
// UsageProgress 使用量进度
|
||||
type UsageProgress struct {
|
||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
@@ -65,21 +61,26 @@ type ClaudeUsageResponse struct {
|
||||
} `json:"seven_day_sonnet"`
|
||||
}
|
||||
|
||||
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||
type ClaudeUsageFetcher interface {
|
||||
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||
}
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
type AccountUsageService struct {
|
||||
repos *repository.Repositories
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
oauthService *OAuthService
|
||||
httpClient *http.Client
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
|
||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
repos: repos,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
oauthService: oauthService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
usageFetcher: usageFetcher,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +89,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS
|
||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||
// API Key账号: 不支持usage查询
|
||||
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
@@ -148,7 +149,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
|
||||
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
||||
return
|
||||
@@ -163,7 +164,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
|
||||
|
||||
// GetTodayStats 获取账号今日统计
|
||||
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
|
||||
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
|
||||
stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get today stats failed: %w", err)
|
||||
}
|
||||
@@ -177,58 +178,23 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
||||
|
||||
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
||||
// 获取access token(从credentials中获取)
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
|
||||
// 获取代理配置
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||
usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var usageResp ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
// 转换为UsageInfo
|
||||
now := time.Now()
|
||||
return s.buildUsageInfo(&usageResp, &now), nil
|
||||
return s.buildUsageInfo(usageResp, &now), nil
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
|
||||
@@ -2,20 +2,14 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -177,37 +171,63 @@ type ProxyTestResult struct {
|
||||
Country string `json:"country,omitempty"`
|
||||
}
|
||||
|
||||
// ProxyExitInfo represents proxy exit information from ipinfo.io
|
||||
type ProxyExitInfo struct {
|
||||
IP string
|
||||
City string
|
||||
Region string
|
||||
Country string
|
||||
}
|
||||
|
||||
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
|
||||
type ProxyExitInfoProber interface {
|
||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||
}
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
accountRepo *repository.AccountRepository
|
||||
proxyRepo *repository.ProxyRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
redeemCodeRepo *repository.RedeemCodeRepository
|
||||
usageLogRepo *repository.UsageLogRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
accountRepo ports.AccountRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
redeemCodeRepo ports.RedeemCodeRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(repos *repository.Repositories, billingCacheService *BillingCacheService) AdminService {
|
||||
func NewAdminService(
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
accountRepo ports.AccountRepository,
|
||||
proxyRepo ports.ProxyRepository,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
redeemCodeRepo ports.RedeemCodeRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: repos.User,
|
||||
groupRepo: repos.Group,
|
||||
accountRepo: repos.Account,
|
||||
proxyRepo: repos.Proxy,
|
||||
apiKeyRepo: repos.ApiKey,
|
||||
redeemCodeRepo: repos.RedeemCode,
|
||||
usageLogRepo: repos.UsageLog,
|
||||
userSubRepo: repos.UserSubscription,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
}
|
||||
}
|
||||
|
||||
// User management implementations
|
||||
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -376,7 +396,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -397,7 +417,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -568,7 +588,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -578,7 +598,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -696,7 +716,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
|
||||
|
||||
// Proxy management implementations
|
||||
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -781,7 +801,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
|
||||
|
||||
// Redeem code management implementations
|
||||
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -865,79 +885,12 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return testProxyConnection(ctx, proxy)
|
||||
}
|
||||
|
||||
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
|
||||
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
|
||||
proxyURL := proxy.URL()
|
||||
|
||||
// Create HTTP client with proxy
|
||||
transport, err := createProxyTransport(proxyURL)
|
||||
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Failed to create proxy transport: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
// Measure latency
|
||||
startTime := time.Now()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Failed to create request: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Proxy connection failed: %v", err),
|
||||
}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
|
||||
LatencyMs: latencyMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Parse ipinfo.io response
|
||||
var ipInfo struct {
|
||||
IP string `json:"ip"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible but failed to read response",
|
||||
LatencyMs: latencyMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible but failed to parse response",
|
||||
LatencyMs: latencyMs,
|
||||
Message: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -945,38 +898,9 @@ func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestRes
|
||||
Success: true,
|
||||
Message: "Proxy is accessible",
|
||||
LatencyMs: latencyMs,
|
||||
IPAddress: ipInfo.IP,
|
||||
City: ipInfo.City,
|
||||
Region: ipInfo.Region,
|
||||
Country: ipInfo.Country,
|
||||
IPAddress: exitInfo.IP,
|
||||
City: exitInfo.City,
|
||||
Region: exitInfo.Region,
|
||||
Country: exitInfo.Country,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createProxyTransport creates an HTTP transport with the given proxy URL
|
||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
case "socks5":
|
||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
@@ -8,8 +8,9 @@ import (
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,18 +18,16 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
apiKeyRateLimitDuration = time.Hour
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
@@ -47,21 +46,21 @@ type UpdateApiKeyRequest struct {
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.ApiKeyCache
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo *repository.ApiKeyRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
groupRepo *repository.GroupRepository,
|
||||
userSubRepo *repository.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.ApiKeyCache,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
@@ -69,7 +68,7 @@ func NewApiKeyService(
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -112,13 +111,11 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
@@ -133,16 +130,11 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
||||
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
|
||||
}
|
||||
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
@@ -237,7 +229,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
@@ -272,7 +264,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
|
||||
}
|
||||
|
||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
||||
if s.rdb != nil {
|
||||
if s.cache != nil {
|
||||
// 这里可以序列化并缓存API Key
|
||||
_ = cacheKey // 使用变量避免未使用错误
|
||||
}
|
||||
@@ -325,9 +317,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
if req.Status != nil {
|
||||
apiKey.Status = *req.Status
|
||||
// 如果状态改变,清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,9 +345,8 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
}
|
||||
|
||||
// 清除Redis缓存
|
||||
if s.rdb != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
|
||||
_ = s.rdb.Del(ctx, cacheKey)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
@@ -399,13 +389,13 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.rdb != nil {
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
|
||||
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
|
||||
return fmt.Errorf("increment usage: %w", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
|
||||
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"log"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -35,7 +35,7 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo *repository.UserRepository
|
||||
userRepo ports.UserRepository
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
@@ -45,7 +45,7 @@ type AuthService struct {
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
userRepo *repository.UserRepository,
|
||||
userRepo ports.UserRepository,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
|
||||
@@ -5,30 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 缓存Key前缀和TTL
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// 订阅缓存Hash字段
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
@@ -38,35 +18,6 @@ var (
|
||||
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
|
||||
)
|
||||
|
||||
// 预编译的Lua脚本
|
||||
var (
|
||||
// deductBalanceScript: 扣减余额缓存,key不存在则忽略
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateSubUsageScript: 更新订阅用量缓存,key不存在则忽略
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
type subscriptionCacheData struct {
|
||||
Status string
|
||||
@@ -80,15 +31,15 @@ type subscriptionCacheData struct {
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
rdb *redis.Client
|
||||
userRepo *repository.UserRepository
|
||||
subRepo *repository.UserSubscriptionRepository
|
||||
cache ports.BillingCache
|
||||
userRepo ports.UserRepository
|
||||
subRepo ports.UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService {
|
||||
func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
}
|
||||
@@ -100,24 +51,19 @@ func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserReposito
|
||||
|
||||
// GetUserBalance 获取用户余额(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
// Redis不可用,直接查询数据库
|
||||
return s.getUserBalanceFromDB(ctx, userID)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
// 尝试从缓存读取
|
||||
val, err := s.rdb.Get(ctx, key).Result()
|
||||
balance, err := s.cache.GetUserBalance(ctx, userID)
|
||||
if err == nil {
|
||||
balance, parseErr := strconv.ParseFloat(val, 64)
|
||||
if parseErr == nil {
|
||||
return balance, nil
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// 缓存未命中或解析错误,从数据库读取
|
||||
balance, err := s.getUserBalanceFromDB(ctx, userID)
|
||||
// 缓存未命中,从数据库读取
|
||||
balance, err = s.getUserBalanceFromDB(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -143,39 +89,28 @@ func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID i
|
||||
|
||||
// setBalanceCache 设置余额缓存
|
||||
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
|
||||
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
|
||||
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
// 使用预编译的Lua脚本原子性扣减,如果key不存在则忽略
|
||||
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
||||
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
|
||||
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
@@ -188,19 +123,14 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID
|
||||
|
||||
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
// 尝试从缓存读取
|
||||
result, err := s.rdb.HGetAll(ctx, key).Result()
|
||||
if err == nil && len(result) > 0 {
|
||||
data, parseErr := s.parseSubscriptionCache(result)
|
||||
if parseErr == nil {
|
||||
return data, nil
|
||||
}
|
||||
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
if err == nil && cacheData != nil {
|
||||
return s.convertFromPortsData(cacheData), nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
@@ -219,6 +149,28 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData {
|
||||
return &subscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData {
|
||||
return &ports.SubscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
// getSubscriptionFromDB 从数据库获取订阅数据
|
||||
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
@@ -236,90 +188,30 @@ func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseSubscriptionCache 解析订阅缓存数据
|
||||
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
|
||||
result := &subscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// setSubscriptionCache 设置订阅缓存
|
||||
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
||||
if s.rdb == nil || data == nil {
|
||||
if s.cache == nil || data == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
|
||||
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
// 使用预编译的Lua脚本原子性增加用量,如果key不存在则忽略
|
||||
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
if err := s.rdb.Del(ctx, key).Err(); err != nil {
|
||||
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
|
||||
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,22 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefixes
|
||||
accountConcurrencyKey = "concurrency:account:"
|
||||
userConcurrencyKey = "concurrency:user:"
|
||||
userWaitCountKey = "concurrency:wait:"
|
||||
|
||||
// TTL for concurrency keys (auto-release safety net)
|
||||
concurrencyKeyTTL = 10 * time.Minute
|
||||
|
||||
// Wait polling interval
|
||||
waitPollInterval = 100 * time.Millisecond
|
||||
|
||||
@@ -28,70 +19,14 @@ const (
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
|
||||
// Pre-compiled Lua scripts for better performance
|
||||
var (
|
||||
// acquireScript: increment counter if below max, return 1 if successful
|
||||
acquireScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current < tonumber(ARGV[1]) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
// releaseScript: decrement counter, but don't go below 0
|
||||
releaseScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementWaitScript: increment wait counter if below max, return 1 if successful
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local waitKey = KEYS[1]
|
||||
local maxWait = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('GET', waitKey)
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
if current >= maxWait then
|
||||
return 0
|
||||
end
|
||||
redis.call('INCR', waitKey)
|
||||
redis.call('EXPIRE', waitKey, ttl)
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript: decrement wait counter, but don't go below 0
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
type ConcurrencyService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.ConcurrencyCache
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
|
||||
return &ConcurrencyService{rdb: rdb}
|
||||
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
@@ -104,20 +39,6 @@ type AcquireResult struct {
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.acquireSlot(ctx, key, maxConcurrency)
|
||||
}
|
||||
|
||||
// acquireSlot is the core implementation for acquiring a concurrency slot
|
||||
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
@@ -126,8 +47,7 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Try to acquire immediately
|
||||
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
|
||||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -135,64 +55,56 @@ func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxCon
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: s.makeReleaseFunc(key),
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: failed to release account slot for %d: %v", accountID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Not acquired, return with Acquired=false
|
||||
// The caller (gateway handler) will handle waiting with ping support
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tryAcquire attempts to increment the counter if below max
|
||||
// Uses pre-compiled Lua script for atomicity and performance
|
||||
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
|
||||
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("acquire slot failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// makeReleaseFunc creates a function to release a concurrency slot
|
||||
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
|
||||
return func() {
|
||||
// Use background context to ensure release even if original context is cancelled
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
|
||||
// Log error but don't panic - TTL will eventually clean up
|
||||
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
|
||||
}
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: failed to release user slot for %d: %v", userID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetCurrentCount returns the current concurrency count for debugging/monitoring
|
||||
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetAccountCurrentCount returns current concurrency count for an account
|
||||
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// GetUserCurrentCount returns current concurrency count for a user
|
||||
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
@@ -203,44 +115,36 @@ func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int
|
||||
// Returns true if successful, false if the wait queue is full.
|
||||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
// Redis not available, allow request
|
||||
return true, nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
|
||||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if err != nil {
|
||||
// On error, allow the request to proceed (fail open)
|
||||
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result == 1, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||||
// Should be called when a request completes or exits the wait queue.
|
||||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
// Use background context to ensure decrement even if original context is cancelled
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
|
||||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserWaitCount returns current wait queue count for a user
|
||||
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
|
||||
return s.GetCurrentCount(ctx, key)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
|
||||
@@ -4,17 +4,14 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -25,19 +22,11 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
verifyCodeKeyPrefix = "email_verify:"
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// verifyCodeData Redis 中存储的验证码数据
|
||||
type verifyCodeData struct {
|
||||
Code string `json:"code"`
|
||||
Attempts int `json:"attempts"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SmtpConfig SMTP配置
|
||||
type SmtpConfig struct {
|
||||
Host string
|
||||
@@ -51,15 +40,15 @@ type SmtpConfig struct {
|
||||
|
||||
// EmailService 邮件服务
|
||||
type EmailService struct {
|
||||
settingRepo *repository.SettingRepository
|
||||
rdb *redis.Client
|
||||
settingRepo ports.SettingRepository
|
||||
cache ports.EmailCache
|
||||
}
|
||||
|
||||
// NewEmailService 创建邮件服务实例
|
||||
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService {
|
||||
func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService {
|
||||
return &EmailService{
|
||||
settingRepo: settingRepo,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,10 +190,8 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
|
||||
|
||||
// SendVerifyCode 发送验证码邮件
|
||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
|
||||
// 检查是否在冷却期内
|
||||
existing, err := s.getVerifyCodeData(ctx, key)
|
||||
existing, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err == nil && existing != nil {
|
||||
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
||||
return ErrVerifyCodeTooFrequent
|
||||
@@ -218,12 +205,12 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
}
|
||||
|
||||
// 保存验证码到 Redis
|
||||
data := &verifyCodeData{
|
||||
data := &ports.VerificationCodeData{
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
return fmt.Errorf("save verify code: %w", err)
|
||||
}
|
||||
|
||||
@@ -241,9 +228,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
|
||||
// VerifyCode 验证验证码
|
||||
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
|
||||
data, err := s.getVerifyCodeData(ctx, key)
|
||||
data, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err != nil || data == nil {
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
@@ -256,7 +241,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
data.Attempts++
|
||||
_ = s.setVerifyCodeData(ctx, key, data)
|
||||
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
@@ -264,32 +249,10 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
s.rdb.Del(ctx, key)
|
||||
_ = s.cache.DeleteVerificationCode(ctx, email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getVerifyCodeData 从 Redis 获取验证码数据
|
||||
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
|
||||
val, err := s.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data verifyCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// setVerifyCodeData 保存验证码数据到 Redis
|
||||
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
|
||||
}
|
||||
|
||||
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
|
||||
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
return fmt.Sprintf(`
|
||||
|
||||
@@ -12,27 +12,27 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ClaudeUpstream handles HTTP requests to Claude API
|
||||
type ClaudeUpstream interface {
|
||||
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||
}
|
||||
|
||||
const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionPrefix = "sticky_session:"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
|
||||
)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
@@ -78,46 +78,48 @@ type ForwardResult struct {
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
repos *repository.Repositories
|
||||
rdb *redis.Client
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.GatewayCache
|
||||
cfg *config.Config
|
||||
oauthService *OAuthService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpClient *http.Client
|
||||
claudeUpstream ClaudeUpstream
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService {
|
||||
// 计算响应头超时时间
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second // 默认5分钟,LLM高负载时可能排队较久
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout, // 等待上游响应头的超时
|
||||
// 注意:不设置整体 Timeout,让流式响应可以无限时间传输
|
||||
}
|
||||
func NewGatewayService(
|
||||
accountRepo ports.AccountRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.GatewayCache,
|
||||
cfg *config.Config,
|
||||
oauthService *OAuthService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
identityService *IdentityService,
|
||||
claudeUpstream ClaudeUpstream,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
repos: repos,
|
||||
rdb: rdb,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
oauthService: oauthService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
// 不设置 Timeout:流式请求可能持续十几分钟
|
||||
// 超时控制由 Transport.ResponseHeaderTimeout 负责(只控制等待响应头)
|
||||
},
|
||||
claudeUpstream: claudeUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,14 +274,14 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
|
||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
s.rdb.Expire(ctx, stickySessionPrefix+sessionHash, stickySessionTTL)
|
||||
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -289,9 +291,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
|
||||
} else {
|
||||
accounts, err = s.repos.Account.ListSchedulable(ctx)
|
||||
accounts, err = s.accountRepo.ListSchedulable(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -329,7 +331,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
s.rdb.Set(ctx, stickySessionPrefix+sessionHash, selected.ID, stickySessionTTL)
|
||||
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
@@ -354,37 +356,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
|
||||
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
|
||||
// 检查是否需要刷新
|
||||
needRefresh := false
|
||||
if expiresAtStr != "" {
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt {
|
||||
needRefresh = true
|
||||
}
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
if needRefresh || accessToken == "" {
|
||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("refresh token failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新账号凭证
|
||||
account.Credentials["access_token"] = tokenInfo.AccessToken
|
||||
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.repos.Account.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo.AccessToken, "oauth", nil
|
||||
}
|
||||
|
||||
// Token刷新由后台 TokenRefreshService 处理,此处只返回当前token
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
|
||||
@@ -420,48 +395,25 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}
|
||||
|
||||
// 构建上游请求
|
||||
upstreamResult, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 选择使用的client:如果有代理则使用独立的client,否则使用共享的httpClient
|
||||
httpClient := s.httpClient
|
||||
if upstreamResult.Client != nil {
|
||||
httpClient = upstreamResult.Client
|
||||
// 获取代理URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := httpClient.Do(upstreamResult.Request)
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理401错误:刷新token重试
|
||||
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
|
||||
resp.Body.Close()
|
||||
token, tokenType, err = s.forceRefreshToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
upstreamResult, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 重试时也需要使用正确的client
|
||||
httpClient = s.httpClient
|
||||
if upstreamResult.Client != nil {
|
||||
httpClient = upstreamResult.Client
|
||||
}
|
||||
resp, err = httpClient.Do(upstreamResult.Request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retry request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
@@ -493,13 +445,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUpstreamRequestResult contains the request and optional custom client for proxy
|
||||
type buildUpstreamRequestResult struct {
|
||||
Request *http.Request
|
||||
Client *http.Client // nil means use default s.httpClient
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
@@ -508,7 +454,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// OAuth账号:应用统一指纹
|
||||
var fingerprint *Fingerprint
|
||||
var fingerprint *ports.Fingerprint
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
@@ -568,36 +514,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
}
|
||||
|
||||
// 配置代理 - 创建独立的client避免并发修改共享httpClient
|
||||
var customClient *http.Client
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
// 计算响应头超时时间(与默认 Transport 保持一致)
|
||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
// 创建独立的client,避免并发时修改共享的s.httpClient.Transport
|
||||
customClient = &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &buildUpstreamRequestResult{
|
||||
Request: req,
|
||||
Client: customClient,
|
||||
}, nil
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// getBetaHeader 处理anthropic-beta header
|
||||
@@ -655,25 +572,6 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
return claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
account.Credentials["access_token"] = tokenInfo.AccessToken
|
||||
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.repos.Account.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo.AccessToken, "oauth", nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -999,7 +897,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil {
|
||||
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1007,7 +905,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if cost.TotalCost > 0 {
|
||||
if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
@@ -1022,7 +920,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if cost.ActualCost > 0 {
|
||||
if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
@@ -1037,7 +935,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
// 更新账号最后使用时间
|
||||
if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil {
|
||||
if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil {
|
||||
log.Printf("Update last used failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1069,50 +967,26 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 构建上游请求
|
||||
upstreamResult, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||
return err
|
||||
}
|
||||
|
||||
// 选择 HTTP client
|
||||
httpClient := s.httpClient
|
||||
if upstreamResult.Client != nil {
|
||||
httpClient = upstreamResult.Client
|
||||
// 获取代理URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := httpClient.Do(upstreamResult.Request)
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理 401 错误:刷新 token 重试(仅 OAuth)
|
||||
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
|
||||
resp.Body.Close()
|
||||
token, tokenType, err = s.forceRefreshToken(ctx, account)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
|
||||
return fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
upstreamResult, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpClient = s.httpClient
|
||||
if upstreamResult.Client != nil {
|
||||
httpClient = upstreamResult.Client
|
||||
}
|
||||
resp, err = httpClient.Do(upstreamResult.Request)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
|
||||
return fmt.Errorf("retry request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -1143,7 +1017,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
@@ -1207,32 +1081,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
}
|
||||
|
||||
// 配置代理
|
||||
var customClient *http.Client
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL := account.Proxy.URL()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
customClient = &http.Client{Transport: transport}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &buildUpstreamRequestResult{
|
||||
Request: req,
|
||||
Client: customClient,
|
||||
}, nil
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// countTokensError 返回 count_tokens 错误响应
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -34,11 +35,11 @@ type UpdateGroupRequest struct {
|
||||
|
||||
// GroupService 分组管理服务
|
||||
type GroupService struct {
|
||||
groupRepo *repository.GroupRepository
|
||||
groupRepo ports.GroupRepository
|
||||
}
|
||||
|
||||
// NewGroupService 创建分组服务实例
|
||||
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService {
|
||||
func NewGroupService(groupRepo ports.GroupRepository) *GroupService {
|
||||
return &GroupService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
@@ -84,7 +85,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
|
||||
}
|
||||
|
||||
// List 获取分组列表
|
||||
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) {
|
||||
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
groups, pagination, err := s.groupRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list groups: %w", err)
|
||||
|
||||
@@ -11,15 +11,10 @@ import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
// Redis key prefix
|
||||
identityFingerprintKey = "identity:fingerprint:"
|
||||
)
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
@@ -29,20 +24,8 @@ var (
|
||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||
)
|
||||
|
||||
// Fingerprint 存储的指纹数据结构
|
||||
type Fingerprint struct {
|
||||
ClientID string `json:"client_id"` // 64位hex客户端ID(首次随机生成)
|
||||
UserAgent string `json:"user_agent"` // User-Agent
|
||||
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
|
||||
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
|
||||
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
|
||||
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
|
||||
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
|
||||
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
|
||||
}
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
var defaultFingerprint = ports.Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
@@ -54,39 +37,31 @@ var defaultFingerprint = Fingerprint{
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
type IdentityService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.IdentityCache
|
||||
}
|
||||
|
||||
// NewIdentityService 创建新的IdentityService
|
||||
func NewIdentityService(rdb *redis.Client) *IdentityService {
|
||||
return &IdentityService{rdb: rdb}
|
||||
func NewIdentityService(cache ports.IdentityCache) *IdentityService {
|
||||
return &IdentityService{cache: cache}
|
||||
}
|
||||
|
||||
// GetOrCreateFingerprint 获取或创建账号的指纹
|
||||
// 如果缓存存在,检测user-agent版本,新版本则更新
|
||||
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
|
||||
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
|
||||
|
||||
// 尝试从Redis获取缓存的指纹
|
||||
data, err := s.rdb.Get(ctx, key).Bytes()
|
||||
if err == nil && len(data) > 0 {
|
||||
// 缓存存在,解析指纹
|
||||
var cached Fingerprint
|
||||
if err := json.Unmarshal(data, &cached); err == nil {
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
if newData, err := json.Marshal(cached); err == nil {
|
||||
s.rdb.Set(ctx, key, newData, 0) // 永不过期
|
||||
}
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return &cached, nil
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) {
|
||||
// 尝试从缓存获取指纹
|
||||
cached, err := s.cache.GetFingerprint(ctx, accountID)
|
||||
if err == nil && cached != nil {
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
_ = s.cache.SetFingerprint(ctx, accountID, cached)
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 缓存不存在或解析失败,创建新指纹
|
||||
@@ -95,11 +70,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
// 生成随机ClientID
|
||||
fp.ClientID = generateClientID()
|
||||
|
||||
// 保存到Redis(永不过期)
|
||||
if data, err := json.Marshal(fp); err == nil {
|
||||
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
// 保存到缓存(永不过期)
|
||||
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
|
||||
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
||||
@@ -107,8 +80,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// createFingerprintFromHeaders 从请求头创建指纹
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
|
||||
fp := &Fingerprint{}
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint {
|
||||
fp := &ports.Fingerprint{}
|
||||
|
||||
// 获取User-Agent
|
||||
if ua := headers.Get("User-Agent"); ua != "" {
|
||||
@@ -137,7 +110,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
|
||||
}
|
||||
|
||||
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) {
|
||||
if fp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,32 +2,36 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/repository"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||
type ClaudeOAuthClient interface {
|
||||
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||
ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||
}
|
||||
|
||||
// OAuthService handles OAuth authentication flows
|
||||
type OAuthService struct {
|
||||
sessionStore *oauth.SessionStore
|
||||
proxyRepo *repository.ProxyRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
oauthClient ClaudeOAuthClient
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService {
|
||||
func NewOAuthService(proxyRepo ports.ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
|
||||
return &OAuthService{
|
||||
sessionStore: oauth.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
oauthClient: oauthClient,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,177 +214,21 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
|
||||
|
||||
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
|
||||
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
targetURL := "https://claude.ai/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetSuccessResult(&orgs).
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if len(orgs) == 0 {
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
return orgs[0].UUID, nil
|
||||
return s.oauthClient.GetOrganizationUUID(ctx, sessionKey, proxyURL)
|
||||
}
|
||||
|
||||
// getAuthorizationCode gets the authorization code using sessionKey
|
||||
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||
|
||||
// Build request body - must include organization_uuid as per CRS
|
||||
reqBody := map[string]interface{}{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID, // Required field!
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
// Response contains redirect_uri with code, not direct code field
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||
SetHeader("Cache-Control", "no-cache").
|
||||
SetHeader("Origin", "https://claude.ai").
|
||||
SetHeader("Referer", "https://claude.ai/new").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&result).
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if result.RedirectURI == "" {
|
||||
return "", fmt.Errorf("no redirect_uri in response")
|
||||
}
|
||||
|
||||
// Parse redirect_uri to extract code and state
|
||||
parsedURL, err := url.Parse(result.RedirectURI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||
}
|
||||
|
||||
queryParams := parsedURL.Query()
|
||||
authCode := queryParams.Get("code")
|
||||
responseState := queryParams.Get("state")
|
||||
|
||||
if authCode == "" {
|
||||
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||
}
|
||||
|
||||
// Combine code with state if present (as CRS does)
|
||||
fullCode := authCode
|
||||
if responseState != "" {
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||
return fullCode, nil
|
||||
return s.oauthClient.GetAuthorizationCode(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges authorization code for tokens
|
||||
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
// Parse code#state format if present
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if parts := strings.Split(code, "#"); len(parts) > 1 {
|
||||
authCode = parts[0]
|
||||
codeState = parts[1]
|
||||
}
|
||||
|
||||
// Build JSON body as CRS does (not form data!)
|
||||
reqBody := map[string]interface{}{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"code_verifier": codeVerifier,
|
||||
}
|
||||
|
||||
// Add state if present
|
||||
if codeState != "" {
|
||||
reqBody["state"] = codeState
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
@@ -390,7 +238,6 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
||||
Scope: tokenResp.Scope,
|
||||
}
|
||||
|
||||
// Extract org_uuid and account_uuid from response
|
||||
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
|
||||
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
||||
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
||||
@@ -405,27 +252,9 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
||||
|
||||
// RefreshToken refreshes an OAuth token
|
||||
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
|
||||
client := s.createReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauth.ClientID)
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(oauth.TokenURL)
|
||||
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenInfo{
|
||||
@@ -455,17 +284,3 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// createReqClient creates a req client with Chrome impersonation and optional proxy
|
||||
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
// Set proxy if specified
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
35
backend/internal/service/ports/account.go
Normal file
35
backend/internal/service/ports/account.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *model.Account) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Account, error)
|
||||
Update(ctx context.Context, account *model.Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListActive(ctx context.Context) ([]model.Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]model.Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
}
|
||||
24
backend/internal/service/ports/api_key.go
Normal file
24
backend/internal/service/ports/api_key.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *model.ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
|
||||
Update(ctx context.Context, key *model.ApiKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
16
backend/internal/service/ports/api_key_cache.go
Normal file
16
backend/internal/service/ports/api_key_cache.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ApiKeyCache defines cache operations for API key service
|
||||
type ApiKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
}
|
||||
31
backend/internal/service/ports/billing_cache.go
Normal file
31
backend/internal/service/ports/billing_cache.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SubscriptionCacheData represents cached subscription data
|
||||
type SubscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
GetUserBalance(ctx context.Context, userID int64) (float64, error)
|
||||
SetUserBalance(ctx context.Context, userID int64, balance float64) error
|
||||
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
|
||||
InvalidateUserBalance(ctx context.Context, userID int64) error
|
||||
|
||||
// Subscription operations
|
||||
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
|
||||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||
}
|
||||
19
backend/internal/service/ports/concurrency_cache.go
Normal file
19
backend/internal/service/ports/concurrency_cache.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package ports
|
||||
|
||||
import "context"
|
||||
|
||||
// ConcurrencyCache defines cache operations for concurrency service
|
||||
type ConcurrencyCache interface {
|
||||
// Slot management
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// Wait queue
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
}
|
||||
20
backend/internal/service/ports/email_cache.go
Normal file
20
backend/internal/service/ports/email_cache.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// VerificationCodeData represents verification code data
|
||||
type VerificationCodeData struct {
|
||||
Code string
|
||||
Attempts int
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// EmailCache defines cache operations for email service
|
||||
type EmailCache interface {
|
||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||
DeleteVerificationCode(ctx context.Context, email string) error
|
||||
}
|
||||
13
backend/internal/service/ports/gateway_cache.go
Normal file
13
backend/internal/service/ports/gateway_cache.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
type GatewayCache interface {
|
||||
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
|
||||
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||
}
|
||||
28
backend/internal/service/ports/group.go
Normal file
28
backend/internal/service/ports/group.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *model.Group) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Group, error)
|
||||
Update(ctx context.Context, group *model.Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
|
||||
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
|
||||
DB() *gorm.DB
|
||||
}
|
||||
21
backend/internal/service/ports/identity_cache.go
Normal file
21
backend/internal/service/ports/identity_cache.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package ports
|
||||
|
||||
import "context"
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
type Fingerprint struct {
|
||||
ClientID string
|
||||
UserAgent string
|
||||
StainlessLang string
|
||||
StainlessPackageVersion string
|
||||
StainlessOS string
|
||||
StainlessArch string
|
||||
StainlessRuntime string
|
||||
StainlessRuntimeVersion string
|
||||
}
|
||||
|
||||
// IdentityCache defines cache operations for identity service
|
||||
type IdentityCache interface {
|
||||
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||
}
|
||||
23
backend/internal/service/ports/proxy.go
Normal file
23
backend/internal/service/ports/proxy.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type ProxyRepository interface {
|
||||
Create(ctx context.Context, proxy *model.Proxy) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
|
||||
Update(ctx context.Context, proxy *model.Proxy) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Proxy, error)
|
||||
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
|
||||
|
||||
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
|
||||
}
|
||||
15
backend/internal/service/ports/redeem_cache.go
Normal file
15
backend/internal/service/ports/redeem_cache.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedeemCache defines cache operations for redeem service
|
||||
type RedeemCache interface {
|
||||
GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementRedeemAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error)
|
||||
ReleaseRedeemLock(ctx context.Context, code string) error
|
||||
}
|
||||
22
backend/internal/service/ports/redeem_code.go
Normal file
22
backend/internal/service/ports/redeem_code.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type RedeemCodeRepository interface {
|
||||
Create(ctx context.Context, code *model.RedeemCode) error
|
||||
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
|
||||
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
|
||||
Update(ctx context.Context, code *model.RedeemCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
Use(ctx context.Context, id, userID int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
|
||||
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
|
||||
}
|
||||
17
backend/internal/service/ports/setting.go
Normal file
17
backend/internal/service/ports/setting.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
)
|
||||
|
||||
type SettingRepository interface {
|
||||
Get(ctx context.Context, key string) (*model.Setting, error)
|
||||
GetValue(ctx context.Context, key string) (string, error)
|
||||
Set(ctx context.Context, key, value string) error
|
||||
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
|
||||
SetMultiple(ctx context.Context, settings map[string]string) error
|
||||
GetAll(ctx context.Context) (map[string]string, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
}
|
||||
12
backend/internal/service/ports/update_cache.go
Normal file
12
backend/internal/service/ports/update_cache.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UpdateCache defines cache operations for update service
|
||||
type UpdateCache interface {
|
||||
GetUpdateInfo(ctx context.Context) (string, error)
|
||||
SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error
|
||||
}
|
||||
28
backend/internal/service/ports/usage_log.go
Normal file
28
backend/internal/service/ports/usage_log.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
Create(ctx context.Context, log *model.UsageLog) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
|
||||
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
|
||||
}
|
||||
26
backend/internal/service/ports/user.go
Normal file
26
backend/internal/service/ports/user.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user *model.User) error
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
GetByEmail(ctx context.Context, email string) (*model.User, error)
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
Update(ctx context.Context, user *model.User) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
|
||||
|
||||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||||
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
36
backend/internal/service/ports/user_subscription.go
Normal file
36
backend/internal/service/ports/user_subscription.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type UserSubscriptionRepository interface {
|
||||
Create(ctx context.Context, sub *model.UserSubscription) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
|
||||
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
Update(ctx context.Context, sub *model.UserSubscription) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
|
||||
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
|
||||
|
||||
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
|
||||
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
|
||||
UpdateStatus(ctx context.Context, subscriptionID int64, status string) error
|
||||
UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error
|
||||
|
||||
ActivateWindows(ctx context.Context, id int64, start time.Time) error
|
||||
ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
IncrementUsage(ctx context.Context, id int64, costUSD float64) error
|
||||
|
||||
BatchUpdateExpiredStatus(ctx context.Context) (int64, error)
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -20,13 +19,19 @@ import (
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
}
|
||||
|
||||
// PricingRemoteClient 远程价格数据获取接口
|
||||
type PricingRemoteClient interface {
|
||||
FetchPricingJSON(ctx context.Context, url string) ([]byte, error)
|
||||
FetchHashText(ctx context.Context, url string) (string, error)
|
||||
}
|
||||
|
||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||
@@ -42,11 +47,12 @@ type LiteLLMRawEntry struct {
|
||||
|
||||
// PricingService 动态价格服务
|
||||
type PricingService struct {
|
||||
cfg *config.Config
|
||||
mu sync.RWMutex
|
||||
pricingData map[string]*LiteLLMModelPricing
|
||||
lastUpdated time.Time
|
||||
localHash string
|
||||
cfg *config.Config
|
||||
remoteClient PricingRemoteClient
|
||||
mu sync.RWMutex
|
||||
pricingData map[string]*LiteLLMModelPricing
|
||||
lastUpdated time.Time
|
||||
localHash string
|
||||
|
||||
// 停止信号
|
||||
stopCh chan struct{}
|
||||
@@ -54,11 +60,12 @@ type PricingService struct {
|
||||
}
|
||||
|
||||
// NewPricingService 创建价格服务
|
||||
func NewPricingService(cfg *config.Config) *PricingService {
|
||||
func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *PricingService {
|
||||
s := &PricingService{
|
||||
cfg: cfg,
|
||||
pricingData: make(map[string]*LiteLLMModelPricing),
|
||||
stopCh: make(chan struct{}),
|
||||
cfg: cfg,
|
||||
remoteClient: remoteClient,
|
||||
pricingData: make(map[string]*LiteLLMModelPricing),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -199,21 +206,13 @@ func (s *PricingService) syncWithRemote() error {
|
||||
func (s *PricingService) downloadPricingData() error {
|
||||
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Get(s.cfg.Pricing.RemoteURL)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read response failed: %w", err)
|
||||
}
|
||||
|
||||
// 解析JSON数据(使用灵活的解析方式)
|
||||
data, err := s.parsePricingData(body)
|
||||
@@ -367,29 +366,10 @@ func (s *PricingService) useFallbackPricing() error {
|
||||
|
||||
// fetchRemoteHash 从远程获取哈希值
|
||||
func (s *PricingService) fetchRemoteHash() (string, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Get(s.cfg.Pricing.HashURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 哈希文件格式:hash filename 或者纯 hash
|
||||
hash := strings.TrimSpace(string(body))
|
||||
parts := strings.Fields(hash)
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
return hash, nil
|
||||
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
|
||||
}
|
||||
|
||||
// computeFileHash 计算文件哈希
|
||||
@@ -466,14 +446,14 @@ func (s *PricingService) extractBaseName(model string) string {
|
||||
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// Claude模型系列匹配规则
|
||||
familyPatterns := map[string][]string{
|
||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
|
||||
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
|
||||
"sonnet-3": {"claude-3-sonnet"},
|
||||
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
|
||||
"haiku-3": {"claude-3-haiku"},
|
||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
|
||||
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
|
||||
"sonnet-3": {"claude-3-sonnet"},
|
||||
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
|
||||
"haiku-3": {"claude-3-haiku"},
|
||||
}
|
||||
|
||||
// 确定模型属于哪个系列
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -37,11 +38,11 @@ type UpdateProxyRequest struct {
|
||||
|
||||
// ProxyService 代理管理服务
|
||||
type ProxyService struct {
|
||||
proxyRepo *repository.ProxyRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
}
|
||||
|
||||
// NewProxyService 创建代理服务实例
|
||||
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
|
||||
func NewProxyService(proxyRepo ports.ProxyRepository) *ProxyService {
|
||||
return &ProxyService{
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
@@ -80,7 +81,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
|
||||
}
|
||||
|
||||
// List 获取代理列表
|
||||
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
|
||||
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
proxies, pagination, err := s.proxyRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list proxies: %w", err)
|
||||
|
||||
@@ -9,20 +9,20 @@ import (
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
type RateLimitService struct {
|
||||
repos *repository.Repositories
|
||||
cfg *config.Config
|
||||
accountRepo ports.AccountRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewRateLimitService 创建RateLimitService实例
|
||||
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
|
||||
func NewRateLimitService(accountRepo ports.AccountRepository, cfg *config.Config) *RateLimitService {
|
||||
return &RateLimitService{
|
||||
repos: repos,
|
||||
cfg: cfg,
|
||||
accountRepo: accountRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
|
||||
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
if resetTimestamp == "" {
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
return
|
||||
@@ -88,7 +88,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
if err != nil {
|
||||
log.Printf("Parse reset timestamp failed: %v", err)
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
return
|
||||
@@ -97,7 +97,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
resetAt := time.Unix(ts, 0)
|
||||
|
||||
// 标记限流状态
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
// 根据重置时间反推5h窗口
|
||||
windowEnd := resetAt
|
||||
windowStart := resetAt.Add(-5 * time.Hour)
|
||||
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
|
||||
}
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -152,13 +152,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
|
||||
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
|
||||
}
|
||||
|
||||
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -166,5 +166,5 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
|
||||
return s.repos.Account.ClearRateLimit(ctx, accountID)
|
||||
return s.accountRepo.ClearRateLimit(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -25,11 +26,9 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
|
||||
redeemLockKeyPrefix = "redeem:lock:"
|
||||
redeemMaxErrorsPerHour = 20
|
||||
redeemRateLimitDuration = time.Hour
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
redeemMaxErrorsPerHour = 20
|
||||
redeemRateLimitDuration = time.Hour
|
||||
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
|
||||
)
|
||||
|
||||
// GenerateCodesRequest 生成兑换码请求
|
||||
@@ -49,26 +48,26 @@ type RedeemCodeResponse struct {
|
||||
|
||||
// RedeemService 兑换码服务
|
||||
type RedeemService struct {
|
||||
redeemRepo *repository.RedeemCodeRepository
|
||||
userRepo *repository.UserRepository
|
||||
redeemRepo ports.RedeemCodeRepository
|
||||
userRepo ports.UserRepository
|
||||
subscriptionService *SubscriptionService
|
||||
rdb *redis.Client
|
||||
cache ports.RedeemCache
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewRedeemService 创建兑换码服务实例
|
||||
func NewRedeemService(
|
||||
redeemRepo *repository.RedeemCodeRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
redeemRepo ports.RedeemCodeRepository,
|
||||
userRepo ports.UserRepository,
|
||||
subscriptionService *SubscriptionService,
|
||||
rdb *redis.Client,
|
||||
cache ports.RedeemCache,
|
||||
billingCacheService *BillingCacheService,
|
||||
) *RedeemService {
|
||||
return &RedeemService{
|
||||
redeemRepo: redeemRepo,
|
||||
userRepo: userRepo,
|
||||
subscriptionService: subscriptionService,
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
@@ -139,13 +138,11 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
|
||||
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
|
||||
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
count, err := s.rdb.Get(ctx, key).Int()
|
||||
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
@@ -160,27 +157,21 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
|
||||
|
||||
// incrementRedeemErrorCount 增加用户兑换错误计数
|
||||
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
pipe := s.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
_, _ = pipe.Exec(ctx)
|
||||
_ = s.cache.IncrementRedeemAttemptCount(ctx, userID)
|
||||
}
|
||||
|
||||
// acquireRedeemLock 尝试获取兑换码的分布式锁
|
||||
// 返回 true 表示获取成功,false 表示锁已被占用
|
||||
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return true // 无 Redis 时降级为不加锁
|
||||
}
|
||||
|
||||
key := redeemLockKeyPrefix + code
|
||||
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
|
||||
ok, err := s.cache.AcquireRedeemLock(ctx, code, redeemLockDuration)
|
||||
if err != nil {
|
||||
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
|
||||
return true
|
||||
@@ -190,12 +181,11 @@ func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool
|
||||
|
||||
// releaseRedeemLock 释放兑换码的分布式锁
|
||||
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
|
||||
if s.rdb == nil {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
key := redeemLockKeyPrefix + code
|
||||
s.rdb.Del(ctx, key)
|
||||
_ = s.cache.ReleaseRedeemLock(ctx, code)
|
||||
}
|
||||
|
||||
// Redeem 使用兑换码
|
||||
@@ -337,7 +327,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
|
||||
}
|
||||
|
||||
// List 获取兑换码列表(管理员功能)
|
||||
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
|
||||
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
codes, pagination, err := s.redeemRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
|
||||
|
||||
@@ -26,4 +26,6 @@ type Services struct {
|
||||
Subscription *SubscriptionService
|
||||
Concurrency *ConcurrencyService
|
||||
Identity *IdentityService
|
||||
Update *UpdateService
|
||||
TokenRefresh *TokenRefreshService
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -18,12 +20,12 @@ var (
|
||||
|
||||
// SettingService 系统设置服务
|
||||
type SettingService struct {
|
||||
settingRepo *repository.SettingRepository
|
||||
settingRepo ports.SettingRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
|
||||
func NewSettingService(settingRepo ports.SettingRepository, cfg *config.Config) *SettingService {
|
||||
return &SettingService{
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
@@ -262,3 +264,63 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// GenerateAdminApiKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
|
||||
// 生成 32 字节随机数 = 64 位十六进制字符
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
key := model.AdminApiKeyPrefix + hex.EncodeToString(bytes)
|
||||
|
||||
// 存储到 settings 表
|
||||
if err := s.settingRepo.Set(ctx, model.SettingKeyAdminApiKey, key); err != nil {
|
||||
return "", fmt.Errorf("save admin api key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetAdminApiKeyStatus 获取管理员 API Key 状态
|
||||
// 返回脱敏的 key、是否存在、错误
|
||||
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, err
|
||||
}
|
||||
if key == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
// 脱敏:显示前 10 位和后 4 位
|
||||
if len(key) > 14 {
|
||||
maskedKey = key[:10] + "..." + key[len(key)-4:]
|
||||
} else {
|
||||
maskedKey = key
|
||||
}
|
||||
|
||||
return maskedKey, true, nil
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
|
||||
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", nil // 未配置,返回空字符串
|
||||
}
|
||||
return "", err // 数据库错误
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, model.SettingKeyAdminApiKey)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,14 +24,16 @@ var (
|
||||
|
||||
// SubscriptionService 订阅服务
|
||||
type SubscriptionService struct {
|
||||
repos *repository.Repositories
|
||||
groupRepo ports.GroupRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewSubscriptionService 创建订阅服务
|
||||
func NewSubscriptionService(repos *repository.Repositories, billingCacheService *BillingCacheService) *SubscriptionService {
|
||||
func NewSubscriptionService(groupRepo ports.GroupRepository, userSubRepo ports.UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
|
||||
return &SubscriptionService{
|
||||
repos: repos,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
@@ -47,7 +50,7 @@ type AssignSubscriptionInput struct {
|
||||
// AssignSubscription 分配订阅给用户(不允许重复分配)
|
||||
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
@@ -56,7 +59,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
}
|
||||
|
||||
// 检查是否已存在订阅
|
||||
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,7 +93,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
// 如果没有订阅:创建新订阅
|
||||
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
@@ -99,7 +102,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 查询是否已有订阅
|
||||
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
// 不存在记录是正常情况,其他错误需要返回
|
||||
existingSub = nil
|
||||
@@ -124,13 +127,13 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 更新过期时间
|
||||
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
return nil, false, fmt.Errorf("extend subscription: %w", err)
|
||||
}
|
||||
|
||||
// 如果订阅已过期或被暂停,恢复为active状态
|
||||
if existingSub.Status != model.SubscriptionStatusActive {
|
||||
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
|
||||
return nil, false, fmt.Errorf("update subscription status: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -142,7 +145,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
newNotes += "\n"
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
// 备注更新失败不影响主流程
|
||||
}
|
||||
}
|
||||
@@ -158,7 +161,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 返回更新后的订阅
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID)
|
||||
return sub, true, err // true 表示是续期
|
||||
}
|
||||
|
||||
@@ -205,12 +208,12 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass
|
||||
sub.AssignedBy = &input.AssignedBy
|
||||
}
|
||||
|
||||
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
|
||||
if err := s.userSubRepo.Create(ctx, sub); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 重新获取完整订阅信息(包含关联)
|
||||
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
|
||||
return s.userSubRepo.GetByID(ctx, sub.ID)
|
||||
}
|
||||
|
||||
// BulkAssignSubscriptionInput 批量分配订阅输入
|
||||
@@ -260,12 +263,12 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
|
||||
// RevokeSubscription 撤销订阅
|
||||
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
|
||||
// 先获取订阅信息用于失效缓存
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
|
||||
if err := s.userSubRepo.Delete(ctx, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -284,20 +287,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
|
||||
// ExtendSubscription 延长订阅
|
||||
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
// 计算新的过期时间
|
||||
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
|
||||
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果订阅已过期,恢复为active状态
|
||||
if sub.Status == model.SubscriptionStatusExpired {
|
||||
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -312,17 +315,17 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
}()
|
||||
}
|
||||
|
||||
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
return s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.GetByID(ctx, id)
|
||||
return s.userSubRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetActiveSubscription 获取用户对特定分组的有效订阅
|
||||
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
@@ -331,24 +334,29 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
|
||||
|
||||
// ListUserSubscriptions 获取用户的所有订阅
|
||||
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.ListByUserID(ctx, userID)
|
||||
return s.userSubRepo.ListByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// ListActiveUserSubscriptions 获取用户的所有有效订阅
|
||||
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
|
||||
return s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// ListGroupSubscriptions 获取分组的所有订阅
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.userSubRepo.ListByGroupID(ctx, groupID, params)
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.userSubRepo.List(ctx, params, userID, groupID, status)
|
||||
}
|
||||
|
||||
// startOfDay 返回给定时间所在日期的零点(保持原时区)
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
|
||||
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
|
||||
@@ -357,39 +365,50 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
|
||||
// 使用当天零点作为窗口起始时间
|
||||
windowStart := startOfDay(time.Now())
|
||||
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
|
||||
}
|
||||
|
||||
// CheckAndResetWindows 检查并重置过期的窗口
|
||||
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
|
||||
now := time.Now()
|
||||
// 使用当天零点作为新窗口起始时间
|
||||
windowStart := startOfDay(time.Now())
|
||||
needsInvalidateCache := false
|
||||
|
||||
// 日窗口重置(24小时)
|
||||
if sub.NeedsDailyReset() {
|
||||
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.DailyWindowStart = &now
|
||||
sub.DailyWindowStart = &windowStart
|
||||
sub.DailyUsageUSD = 0
|
||||
needsInvalidateCache = true
|
||||
}
|
||||
|
||||
// 周窗口重置(7天)
|
||||
if sub.NeedsWeeklyReset() {
|
||||
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.WeeklyWindowStart = &now
|
||||
sub.WeeklyWindowStart = &windowStart
|
||||
sub.WeeklyUsageUSD = 0
|
||||
needsInvalidateCache = true
|
||||
}
|
||||
|
||||
// 月窗口重置(30天)
|
||||
if sub.NeedsMonthlyReset() {
|
||||
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.MonthlyWindowStart = &now
|
||||
sub.MonthlyWindowStart = &windowStart
|
||||
sub.MonthlyUsageUSD = 0
|
||||
needsInvalidateCache = true
|
||||
}
|
||||
|
||||
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
|
||||
if needsInvalidateCache && s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -411,7 +430,7 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.U
|
||||
|
||||
// RecordUsage 记录使用量到订阅
|
||||
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
|
||||
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
}
|
||||
|
||||
// SubscriptionProgress 订阅进度
|
||||
@@ -438,14 +457,14 @@ type UsageWindowProgress struct {
|
||||
|
||||
// GetSubscriptionProgress 获取订阅使用进度
|
||||
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
group := sub.Group
|
||||
if group == nil {
|
||||
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
|
||||
group, err = s.groupRepo.GetByID(ctx, sub.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -535,7 +554,7 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
|
||||
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
|
||||
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
|
||||
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -554,7 +573,7 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
|
||||
|
||||
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
|
||||
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
|
||||
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
|
||||
return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
|
||||
}
|
||||
|
||||
// ValidateSubscription 验证订阅是否有效
|
||||
@@ -567,7 +586,7 @@ func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *mod
|
||||
}
|
||||
if sub.IsExpired() {
|
||||
// 更新状态
|
||||
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
return ErrSubscriptionExpired
|
||||
}
|
||||
return nil
|
||||
|
||||
185
backend/internal/service/token_refresh_service.go
Normal file
185
backend/internal/service/token_refresh_service.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// TokenRefreshService OAuth token自动刷新服务
|
||||
// 定期检查并刷新即将过期的token
|
||||
type TokenRefreshService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
refreshers []TokenRefresher
|
||||
cfg *config.TokenRefreshConfig
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewTokenRefreshService 创建token刷新服务
|
||||
func NewTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
accountRepo: accountRepo,
|
||||
cfg: &cfg.TokenRefresh,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 注册平台特定的刷新器
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
// 未来可以添加其他平台的刷新器:
|
||||
// NewOpenAITokenRefresher(...),
|
||||
// NewGeminiTokenRefresher(...),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
log.Println("[TokenRefresh] Service disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.refreshLoop()
|
||||
|
||||
log.Printf("[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)",
|
||||
s.cfg.CheckIntervalMinutes, s.cfg.RefreshBeforeExpiryHours)
|
||||
}
|
||||
|
||||
// Stop 停止刷新服务
|
||||
func (s *TokenRefreshService) Stop() {
|
||||
close(s.stopCh)
|
||||
s.wg.Wait()
|
||||
log.Println("[TokenRefresh] Service stopped")
|
||||
}
|
||||
|
||||
// refreshLoop 刷新循环
|
||||
func (s *TokenRefreshService) refreshLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 计算检查间隔
|
||||
checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
|
||||
if checkInterval < time.Minute {
|
||||
checkInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动时立即执行一次检查
|
||||
s.processRefresh()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.processRefresh()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processRefresh 执行一次刷新检查
|
||||
func (s *TokenRefreshService) processRefresh() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 计算刷新窗口
|
||||
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
|
||||
|
||||
// 获取所有active状态的账号
|
||||
accounts, err := s.listActiveAccounts(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to list accounts: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
refreshed, failed := 0, 0
|
||||
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
|
||||
// 遍历所有刷新器,找到能处理此账号的
|
||||
for _, refresher := range s.refreshers {
|
||||
if !refresher.CanRefresh(account) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否需要刷新
|
||||
if !refresher.NeedsRefresh(account, refreshWindow) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行刷新
|
||||
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
|
||||
log.Printf("[TokenRefresh] Account %d (%s) failed: %v", account.ID, account.Name, err)
|
||||
failed++
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", account.ID, account.Name)
|
||||
refreshed++
|
||||
}
|
||||
|
||||
// 每个账号只由一个refresher处理
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if refreshed > 0 || failed > 0 {
|
||||
log.Printf("[TokenRefresh] Cycle complete: %d refreshed, %d failed", refreshed, failed)
|
||||
}
|
||||
}
|
||||
|
||||
// listActiveAccounts 获取所有active状态的账号
|
||||
// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的)
|
||||
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
|
||||
return s.accountRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
// refreshWithRetry 带重试的刷新
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
if err == nil {
|
||||
// 刷新成功,更新账号credentials
|
||||
account.Credentials = model.JSONB(newCredentials)
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
|
||||
account.ID, attempt, s.cfg.MaxRetries, err)
|
||||
|
||||
// 如果还有重试机会,等待后重试
|
||||
if attempt < s.cfg.MaxRetries {
|
||||
// 指数退避:2^(attempt-1) * baseSeconds
|
||||
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
// 所有重试都失败,标记账号为error状态
|
||||
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
90
backend/internal/service/token_refresher.go
Normal file
90
backend/internal/service/token_refresher.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
)
|
||||
|
||||
// TokenRefresher 定义平台特定的token刷新策略接口
|
||||
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
|
||||
type TokenRefresher interface {
|
||||
// CanRefresh 检查此刷新器是否能处理指定账号
|
||||
CanRefresh(account *model.Account) bool
|
||||
|
||||
// NeedsRefresh 检查账号的token是否需要刷新
|
||||
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
|
||||
|
||||
// Refresh 执行token刷新,返回更新后的credentials
|
||||
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
||||
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
||||
type ClaudeTokenRefresher struct {
|
||||
oauthService *OAuthService
|
||||
}
|
||||
|
||||
// NewClaudeTokenRefresher 创建Claude token刷新器
|
||||
func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
|
||||
return &ClaudeTokenRefresher{
|
||||
oauthService: oauthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 anthropic 平台的 oauth 类型账号
|
||||
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
|
||||
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
|
||||
return account.Platform == model.PlatformAnthropic &&
|
||||
account.Type == model.AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
expiryTime := time.Unix(expiresAt, 0)
|
||||
return time.Until(expiryTime) < refreshWindow
|
||||
}
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
|
||||
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留现有credentials中的所有字段
|
||||
newCredentials := make(map[string]interface{})
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
// 只更新token相关字段
|
||||
// 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
@@ -2,14 +2,9 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -19,10 +14,15 @@ var (
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||
type TurnstileVerifier interface {
|
||||
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
||||
}
|
||||
|
||||
// TurnstileService Turnstile 验证服务
|
||||
type TurnstileService struct {
|
||||
settingService *SettingService
|
||||
httpClient *http.Client
|
||||
verifier TurnstileVerifier
|
||||
}
|
||||
|
||||
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
|
||||
@@ -36,12 +36,10 @@ type TurnstileVerifyResponse struct {
|
||||
}
|
||||
|
||||
// NewTurnstileService 创建 Turnstile 服务实例
|
||||
func NewTurnstileService(settingService *SettingService) *TurnstileService {
|
||||
func NewTurnstileService(settingService *SettingService, verifier TurnstileVerifier) *TurnstileService {
|
||||
return &TurnstileService{
|
||||
settingService: settingService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,35 +64,12 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote
|
||||
return ErrTurnstileVerificationFailed
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
formData := url.Values{}
|
||||
formData.Set("secret", secretKey)
|
||||
formData.Set("response", token)
|
||||
if remoteIP != "" {
|
||||
formData.Set("remoteip", remoteIP)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// 发送请求
|
||||
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
|
||||
resp, err := s.httpClient.Do(req)
|
||||
result, err := s.verifier.VerifyToken(ctx, secretKey, token, remoteIP)
|
||||
if err != nil {
|
||||
log.Printf("[Turnstile] Request failed: %v", err)
|
||||
return fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 解析响应
|
||||
var result TurnstileVerifyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
log.Printf("[Turnstile] Failed to decode response: %v", err)
|
||||
return fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -18,7 +17,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,17 +33,26 @@ const (
|
||||
maxDownloadSize = 500 * 1024 * 1024
|
||||
)
|
||||
|
||||
// GitHubReleaseClient 获取 GitHub release 信息的接口
|
||||
type GitHubReleaseClient interface {
|
||||
FetchLatestRelease(ctx context.Context, repo string) (*GitHubRelease, error)
|
||||
DownloadFile(ctx context.Context, url, dest string, maxSize int64) error
|
||||
FetchChecksumFile(ctx context.Context, url string) ([]byte, error)
|
||||
}
|
||||
|
||||
// UpdateService handles software updates
|
||||
type UpdateService struct {
|
||||
rdb *redis.Client
|
||||
cache ports.UpdateCache
|
||||
githubClient GitHubReleaseClient
|
||||
currentVersion string
|
||||
buildType string // "source" for manual builds, "release" for CI builds
|
||||
}
|
||||
|
||||
// NewUpdateService creates a new UpdateService
|
||||
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
|
||||
func NewUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, version, buildType string) *UpdateService {
|
||||
return &UpdateService{
|
||||
rdb: rdb,
|
||||
cache: cache,
|
||||
githubClient: githubClient,
|
||||
currentVersion: version,
|
||||
buildType: buildType,
|
||||
}
|
||||
@@ -260,42 +268,11 @@ func (s *UpdateService) Rollback() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
release, err := s.githubClient.FetchLatestRelease(ctx, githubRepo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return &UpdateInfo{
|
||||
CurrentVersion: s.currentVersion,
|
||||
LatestVersion: s.currentVersion,
|
||||
HasUpdate: false,
|
||||
Warning: "No releases found",
|
||||
BuildType: s.buildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release GitHubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
latestVersion := strings.TrimPrefix(release.TagName, "v")
|
||||
|
||||
@@ -325,47 +302,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
|
||||
}
|
||||
|
||||
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// SECURITY: Check Content-Length if available
|
||||
if resp.ContentLength > maxDownloadSize {
|
||||
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
|
||||
}
|
||||
|
||||
out, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxDownloadSize)
|
||||
if written > maxDownloadSize {
|
||||
os.Remove(dest) // Clean up partial file
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.githubClient.DownloadFile(ctx, downloadURL, dest, maxDownloadSize)
|
||||
}
|
||||
|
||||
func (s *UpdateService) getArchiveName() string {
|
||||
@@ -402,20 +339,9 @@ func validateDownloadURL(rawURL string) error {
|
||||
|
||||
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
|
||||
// Download checksums file
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
|
||||
checksumData, err := s.githubClient.FetchChecksumFile(ctx, checksumURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
|
||||
return fmt.Errorf("failed to download checksums: %w", err)
|
||||
}
|
||||
|
||||
// Calculate file hash
|
||||
@@ -433,7 +359,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
|
||||
|
||||
// Find expected hash in checksums file
|
||||
fileName := filepath.Base(filePath)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner := bufio.NewScanner(strings.NewReader(string(checksumData)))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
parts := strings.Fields(line)
|
||||
@@ -533,7 +459,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
}
|
||||
|
||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
|
||||
data, err := s.cache.GetUpdateInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -573,7 +499,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(cacheData)
|
||||
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
|
||||
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
}
|
||||
|
||||
// compareVersions compares two semantic versions
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -41,24 +42,24 @@ type CreateUsageLogRequest struct {
|
||||
|
||||
// UsageStats 使用统计
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
|
||||
// UsageService 使用统计服务
|
||||
type UsageService struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
userRepo *repository.UserRepository
|
||||
usageRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
}
|
||||
|
||||
// NewUsageService 创建使用统计服务实例
|
||||
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
|
||||
func NewUsageService(usageRepo ports.UsageLogRepository, userRepo ports.UserRepository) *UsageService {
|
||||
return &UsageService{
|
||||
usageRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
@@ -127,7 +128,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
|
||||
}
|
||||
|
||||
// ListByUser 获取用户的使用日志列表
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -136,7 +137,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repo
|
||||
}
|
||||
|
||||
// ListByApiKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -145,7 +146,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
|
||||
}
|
||||
|
||||
// ListByAccount 获取账号的使用日志列表
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -233,15 +234,15 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,16 +6,17 @@ import (
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrPasswordIncorrect = errors.New("current password is incorrect")
|
||||
ErrInsufficientPerms = errors.New("insufficient permissions")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrPasswordIncorrect = errors.New("current password is incorrect")
|
||||
ErrInsufficientPerms = errors.New("insufficient permissions")
|
||||
)
|
||||
|
||||
// UpdateProfileRequest 更新用户资料请求
|
||||
@@ -32,12 +33,12 @@ type ChangePasswordRequest struct {
|
||||
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo *repository.UserRepository
|
||||
userRepo ports.UserRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
|
||||
func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
@@ -133,7 +134,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
|
||||
}
|
||||
|
||||
// List 获取用户列表(管理员功能)
|
||||
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
|
||||
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
users, pagination, err := s.userRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list users: %w", err)
|
||||
|
||||
@@ -2,13 +2,20 @@ package service
|
||||
|
||||
import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// BuildInfo contains build information
|
||||
type BuildInfo struct {
|
||||
Version string
|
||||
BuildType string
|
||||
}
|
||||
|
||||
// ProvidePricingService creates and initializes PricingService
|
||||
func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
|
||||
svc := NewPricingService(cfg)
|
||||
func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) {
|
||||
svc := NewPricingService(cfg, remoteClient)
|
||||
if err := svc.Initialize(); err != nil {
|
||||
// 价格服务初始化失败不应阻止启动,使用回退价格
|
||||
println("[Service] Warning: Pricing service initialization failed:", err.Error())
|
||||
@@ -16,11 +23,27 @@ func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// ProvideUpdateService creates UpdateService with BuildInfo
|
||||
func ProvideUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, buildInfo BuildInfo) *UpdateService {
|
||||
return NewUpdateService(cache, githubClient, buildInfo.Version, buildInfo.BuildType)
|
||||
}
|
||||
|
||||
// ProvideEmailQueueService creates EmailQueueService with default worker count
|
||||
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
return NewEmailQueueService(emailService, 3)
|
||||
}
|
||||
|
||||
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
@@ -48,6 +71,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewSubscriptionService,
|
||||
NewConcurrencyService,
|
||||
NewIdentityService,
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
|
||||
// Provide the Services container struct
|
||||
wire.Struct(new(Services), "*"),
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
<script setup lang="ts">
|
||||
import { RouterView, useRouter, useRoute } from 'vue-router'
|
||||
import { onMounted } from 'vue'
|
||||
import { onMounted, watch } from 'vue'
|
||||
import Toast from '@/components/common/Toast.vue'
|
||||
import { getPublicSettings } from '@/api/auth'
|
||||
import { useAppStore } from '@/stores'
|
||||
import { getSetupStatus } from '@/api/setup'
|
||||
|
||||
const router = useRouter()
|
||||
const route = useRoute()
|
||||
const appStore = useAppStore()
|
||||
|
||||
/**
|
||||
* Update favicon dynamically
|
||||
@@ -24,6 +25,19 @@ function updateFavicon(logoUrl: string) {
|
||||
link.href = logoUrl
|
||||
}
|
||||
|
||||
// Watch for site settings changes and update favicon/title
|
||||
watch(() => appStore.siteLogo, (newLogo) => {
|
||||
if (newLogo) {
|
||||
updateFavicon(newLogo)
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
watch(() => appStore.siteName, (newName) => {
|
||||
if (newName) {
|
||||
document.title = `${newName} - AI API Gateway`
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
onMounted(async () => {
|
||||
// Check if setup is needed
|
||||
try {
|
||||
@@ -36,21 +50,8 @@ onMounted(async () => {
|
||||
// If setup endpoint fails, assume normal mode and continue
|
||||
}
|
||||
|
||||
try {
|
||||
const settings = await getPublicSettings()
|
||||
|
||||
// Update favicon if logo is set
|
||||
if (settings.site_logo) {
|
||||
updateFavicon(settings.site_logo)
|
||||
}
|
||||
|
||||
// Update page title if site name is set
|
||||
if (settings.site_name) {
|
||||
document.title = `${settings.site_name} - AI API Gateway`
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load public settings for favicon:', error)
|
||||
}
|
||||
// Load public settings into appStore (will be cached for other components)
|
||||
await appStore.fetchPublicSettings()
|
||||
})
|
||||
</script>
|
||||
|
||||
|
||||
@@ -38,6 +38,8 @@ export interface TrendParams {
|
||||
start_date?: string;
|
||||
end_date?: string;
|
||||
granularity?: 'day' | 'hour';
|
||||
user_id?: number;
|
||||
api_key_id?: number;
|
||||
}
|
||||
|
||||
export interface TrendResponse {
|
||||
@@ -57,6 +59,13 @@ export async function getUsageTrend(params?: TrendParams): Promise<TrendResponse
|
||||
return data;
|
||||
}
|
||||
|
||||
export interface ModelStatsParams {
|
||||
start_date?: string;
|
||||
end_date?: string;
|
||||
user_id?: number;
|
||||
api_key_id?: number;
|
||||
}
|
||||
|
||||
export interface ModelStatsResponse {
|
||||
models: ModelStat[];
|
||||
start_date: string;
|
||||
@@ -68,7 +77,7 @@ export interface ModelStatsResponse {
|
||||
* @param params - Query parameters for filtering
|
||||
* @returns Model usage statistics
|
||||
*/
|
||||
export async function getModelStats(params?: { start_date?: string; end_date?: string }): Promise<ModelStatsResponse> {
|
||||
export async function getModelStats(params?: ModelStatsParams): Promise<ModelStatsResponse> {
|
||||
const { data } = await apiClient.get<ModelStatsResponse>('/admin/dashboard/models', { params });
|
||||
return data;
|
||||
}
|
||||
|
||||
@@ -99,11 +99,49 @@ export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ me
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Admin API Key status response
|
||||
*/
|
||||
export interface AdminApiKeyStatus {
|
||||
exists: boolean;
|
||||
masked_key: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get admin API key status
|
||||
* @returns Status indicating if key exists and masked version
|
||||
*/
|
||||
export async function getAdminApiKey(): Promise<AdminApiKeyStatus> {
|
||||
const { data } = await apiClient.get<AdminApiKeyStatus>('/admin/settings/admin-api-key');
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Regenerate admin API key
|
||||
* @returns The new full API key (only shown once)
|
||||
*/
|
||||
export async function regenerateAdminApiKey(): Promise<{ key: string }> {
|
||||
const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate');
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete admin API key
|
||||
* @returns Success message
|
||||
*/
|
||||
export async function deleteAdminApiKey(): Promise<{ message: string }> {
|
||||
const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key');
|
||||
return data;
|
||||
}
|
||||
|
||||
export const settingsAPI = {
|
||||
getSettings,
|
||||
updateSettings,
|
||||
testSmtpConnection,
|
||||
sendTestEmail,
|
||||
getAdminApiKey,
|
||||
regenerateAdminApiKey,
|
||||
deleteAdminApiKey,
|
||||
};
|
||||
|
||||
export default settingsAPI;
|
||||
|
||||
125
frontend/src/components/charts/ModelDistributionChart.vue
Normal file
125
frontend/src/components/charts/ModelDistributionChart.vue
Normal file
@@ -0,0 +1,125 @@
|
||||
<template>
|
||||
<div class="card p-4">
|
||||
<h3 class="text-sm font-semibold text-gray-900 dark:text-white mb-4">{{ t('admin.dashboard.modelDistribution') }}</h3>
|
||||
<div v-if="loading" class="flex items-center justify-center h-48">
|
||||
<LoadingSpinner />
|
||||
</div>
|
||||
<div v-else-if="modelStats.length > 0 && chartData" class="flex items-center gap-6">
|
||||
<div class="w-48 h-48">
|
||||
<Doughnut :data="chartData" :options="doughnutOptions" />
|
||||
</div>
|
||||
<div class="flex-1 max-h-48 overflow-y-auto">
|
||||
<table class="w-full text-xs">
|
||||
<thead>
|
||||
<tr class="text-gray-500 dark:text-gray-400">
|
||||
<th class="text-left pb-2">{{ t('admin.dashboard.model') }}</th>
|
||||
<th class="text-right pb-2">{{ t('admin.dashboard.requests') }}</th>
|
||||
<th class="text-right pb-2">{{ t('admin.dashboard.tokens') }}</th>
|
||||
<th class="text-right pb-2">{{ t('admin.dashboard.actual') }}</th>
|
||||
<th class="text-right pb-2">{{ t('admin.dashboard.standard') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="model in modelStats" :key="model.model" class="border-t border-gray-100 dark:border-gray-700">
|
||||
<td class="py-1.5 text-gray-900 dark:text-white font-medium truncate max-w-[100px]" :title="model.model">{{ model.model }}</td>
|
||||
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">{{ formatNumber(model.requests) }}</td>
|
||||
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">{{ formatTokens(model.total_tokens) }}</td>
|
||||
<td class="py-1.5 text-right text-green-600 dark:text-green-400">${{ formatCost(model.actual_cost) }}</td>
|
||||
<td class="py-1.5 text-right text-gray-400 dark:text-gray-500">${{ formatCost(model.cost) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else class="flex items-center justify-center h-48 text-gray-500 dark:text-gray-400 text-sm">
|
||||
{{ t('admin.dashboard.noDataAvailable') }}
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import {
|
||||
Chart as ChartJS,
|
||||
ArcElement,
|
||||
Tooltip,
|
||||
Legend
|
||||
} from 'chart.js'
|
||||
import { Doughnut } from 'vue-chartjs'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import type { ModelStat } from '@/types'
|
||||
|
||||
ChartJS.register(ArcElement, Tooltip, Legend)
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
modelStats: ModelStat[]
|
||||
loading?: boolean
|
||||
}>()
|
||||
|
||||
const chartColors = [
|
||||
'#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6',
|
||||
'#ec4899', '#14b8a6', '#f97316', '#6366f1', '#84cc16'
|
||||
]
|
||||
|
||||
const chartData = computed(() => {
|
||||
if (!props.modelStats?.length) return null
|
||||
|
||||
return {
|
||||
labels: props.modelStats.map(m => m.model),
|
||||
datasets: [{
|
||||
data: props.modelStats.map(m => m.total_tokens),
|
||||
backgroundColor: chartColors.slice(0, props.modelStats.length),
|
||||
borderWidth: 0,
|
||||
}],
|
||||
}
|
||||
})
|
||||
|
||||
const doughnutOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
legend: {
|
||||
display: false,
|
||||
},
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
label: (context: any) => {
|
||||
const value = context.raw as number
|
||||
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
|
||||
const percentage = ((value / total) * 100).toFixed(1)
|
||||
return `${context.label}: ${formatTokens(value)} (${percentage}%)`
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
const formatTokens = (value: number): string => {
|
||||
if (value >= 1_000_000_000) {
|
||||
return `${(value / 1_000_000_000).toFixed(2)}B`
|
||||
} else if (value >= 1_000_000) {
|
||||
return `${(value / 1_000_000).toFixed(2)}M`
|
||||
} else if (value >= 1_000) {
|
||||
return `${(value / 1_000).toFixed(2)}K`
|
||||
}
|
||||
return value.toLocaleString()
|
||||
}
|
||||
|
||||
const formatNumber = (value: number): string => {
|
||||
return value.toLocaleString()
|
||||
}
|
||||
|
||||
const formatCost = (value: number): string => {
|
||||
if (value >= 1000) {
|
||||
return (value / 1000).toFixed(2) + 'K'
|
||||
} else if (value >= 1) {
|
||||
return value.toFixed(2)
|
||||
} else if (value >= 0.01) {
|
||||
return value.toFixed(3)
|
||||
}
|
||||
return value.toFixed(4)
|
||||
}
|
||||
</script>
|
||||
182
frontend/src/components/charts/TokenUsageTrend.vue
Normal file
182
frontend/src/components/charts/TokenUsageTrend.vue
Normal file
@@ -0,0 +1,182 @@
|
||||
<template>
|
||||
<div class="card p-4">
|
||||
<h3 class="text-sm font-semibold text-gray-900 dark:text-white mb-4">{{ t('admin.dashboard.tokenUsageTrend') }}</h3>
|
||||
<div v-if="loading" class="flex items-center justify-center h-48">
|
||||
<LoadingSpinner />
|
||||
</div>
|
||||
<div v-else-if="trendData.length > 0 && chartData" class="h-48">
|
||||
<Line :data="chartData" :options="lineOptions" />
|
||||
</div>
|
||||
<div v-else class="flex items-center justify-center h-48 text-gray-500 dark:text-gray-400 text-sm">
|
||||
{{ t('admin.dashboard.noDataAvailable') }}
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import {
|
||||
Chart as ChartJS,
|
||||
CategoryScale,
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
Title,
|
||||
Tooltip,
|
||||
Legend,
|
||||
Filler
|
||||
} from 'chart.js'
|
||||
import { Line } from 'vue-chartjs'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import type { TrendDataPoint } from '@/types'
|
||||
|
||||
ChartJS.register(
|
||||
CategoryScale,
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
Title,
|
||||
Tooltip,
|
||||
Legend,
|
||||
Filler
|
||||
)
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
trendData: TrendDataPoint[]
|
||||
loading?: boolean
|
||||
}>()
|
||||
|
||||
const isDarkMode = computed(() => {
|
||||
return document.documentElement.classList.contains('dark')
|
||||
})
|
||||
|
||||
const chartColors = computed(() => ({
|
||||
text: isDarkMode.value ? '#e5e7eb' : '#374151',
|
||||
grid: isDarkMode.value ? '#374151' : '#e5e7eb',
|
||||
input: '#3b82f6',
|
||||
output: '#10b981',
|
||||
cache: '#f59e0b',
|
||||
}))
|
||||
|
||||
const chartData = computed(() => {
|
||||
if (!props.trendData?.length) return null
|
||||
|
||||
return {
|
||||
labels: props.trendData.map(d => d.date),
|
||||
datasets: [
|
||||
{
|
||||
label: 'Input',
|
||||
data: props.trendData.map(d => d.input_tokens),
|
||||
borderColor: chartColors.value.input,
|
||||
backgroundColor: `${chartColors.value.input}20`,
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
},
|
||||
{
|
||||
label: 'Output',
|
||||
data: props.trendData.map(d => d.output_tokens),
|
||||
borderColor: chartColors.value.output,
|
||||
backgroundColor: `${chartColors.value.output}20`,
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
},
|
||||
{
|
||||
label: 'Cache',
|
||||
data: props.trendData.map(d => d.cache_tokens),
|
||||
borderColor: chartColors.value.cache,
|
||||
backgroundColor: `${chartColors.value.cache}20`,
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
},
|
||||
],
|
||||
}
|
||||
})
|
||||
|
||||
const lineOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
interaction: {
|
||||
intersect: false,
|
||||
mode: 'index' as const,
|
||||
},
|
||||
plugins: {
|
||||
legend: {
|
||||
position: 'top' as const,
|
||||
labels: {
|
||||
color: chartColors.value.text,
|
||||
usePointStyle: true,
|
||||
pointStyle: 'circle',
|
||||
padding: 15,
|
||||
font: {
|
||||
size: 11,
|
||||
},
|
||||
},
|
||||
},
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
label: (context: any) => {
|
||||
return `${context.dataset.label}: ${formatTokens(context.raw)}`
|
||||
},
|
||||
footer: (tooltipItems: any) => {
|
||||
const dataIndex = tooltipItems[0]?.dataIndex
|
||||
if (dataIndex !== undefined && props.trendData[dataIndex]) {
|
||||
const data = props.trendData[dataIndex]
|
||||
return `Actual: $${formatCost(data.actual_cost)} | Standard: $${formatCost(data.cost)}`
|
||||
}
|
||||
return ''
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
scales: {
|
||||
x: {
|
||||
grid: {
|
||||
color: chartColors.value.grid,
|
||||
},
|
||||
ticks: {
|
||||
color: chartColors.value.text,
|
||||
font: {
|
||||
size: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
y: {
|
||||
grid: {
|
||||
color: chartColors.value.grid,
|
||||
},
|
||||
ticks: {
|
||||
color: chartColors.value.text,
|
||||
font: {
|
||||
size: 10,
|
||||
},
|
||||
callback: (value: string | number) => formatTokens(Number(value)),
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
const formatTokens = (value: number): string => {
|
||||
if (value >= 1_000_000_000) {
|
||||
return `${(value / 1_000_000_000).toFixed(2)}B`
|
||||
} else if (value >= 1_000_000) {
|
||||
return `${(value / 1_000_000).toFixed(2)}M`
|
||||
} else if (value >= 1_000) {
|
||||
return `${(value / 1_000).toFixed(2)}K`
|
||||
}
|
||||
return value.toLocaleString()
|
||||
}
|
||||
|
||||
const formatCost = (value: number): string => {
|
||||
if (value >= 1000) {
|
||||
return (value / 1000).toFixed(2) + 'K'
|
||||
} else if (value >= 1) {
|
||||
return value.toFixed(2)
|
||||
} else if (value >= 0.01) {
|
||||
return value.toFixed(3)
|
||||
}
|
||||
return value.toFixed(4)
|
||||
}
|
||||
</script>
|
||||
@@ -156,7 +156,6 @@ import { ref, computed, onMounted, onBeforeUnmount } from 'vue';
|
||||
import { useRouter, useRoute } from 'vue-router';
|
||||
import { useI18n } from 'vue-i18n';
|
||||
import { useAppStore, useAuthStore } from '@/stores';
|
||||
import { authAPI } from '@/api';
|
||||
import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue';
|
||||
import SubscriptionProgressMini from '@/components/common/SubscriptionProgressMini.vue';
|
||||
|
||||
@@ -169,7 +168,7 @@ const authStore = useAuthStore();
|
||||
const user = computed(() => authStore.user);
|
||||
const dropdownOpen = ref(false);
|
||||
const dropdownRef = ref<HTMLElement | null>(null);
|
||||
const contactInfo = ref('');
|
||||
const contactInfo = computed(() => appStore.contactInfo);
|
||||
|
||||
const userInitials = computed(() => {
|
||||
if (!user.value) return '';
|
||||
@@ -230,14 +229,8 @@ function handleClickOutside(event: MouseEvent) {
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(async () => {
|
||||
onMounted(() => {
|
||||
document.addEventListener('click', handleClickOutside);
|
||||
try {
|
||||
const settings = await authAPI.getPublicSettings();
|
||||
contactInfo.value = settings.contact_info || '';
|
||||
} catch (error) {
|
||||
console.error('Failed to load contact info:', error);
|
||||
}
|
||||
});
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
|
||||
@@ -131,11 +131,10 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, h, ref, onMounted } from 'vue';
|
||||
import { computed, h, ref } from 'vue';
|
||||
import { useRoute } from 'vue-router';
|
||||
import { useI18n } from 'vue-i18n';
|
||||
import { useAppStore, useAuthStore } from '@/stores';
|
||||
import { getPublicSettings } from '@/api/auth';
|
||||
import VersionBadge from '@/components/common/VersionBadge.vue';
|
||||
|
||||
const { t } = useI18n();
|
||||
@@ -149,21 +148,10 @@ const mobileOpen = computed(() => appStore.mobileOpen);
|
||||
const isAdmin = computed(() => authStore.isAdmin);
|
||||
const isDark = ref(document.documentElement.classList.contains('dark'));
|
||||
|
||||
// Site settings
|
||||
const siteName = ref('Sub2API');
|
||||
const siteLogo = ref('');
|
||||
const siteVersion = ref('');
|
||||
|
||||
onMounted(async () => {
|
||||
try {
|
||||
const settings = await getPublicSettings();
|
||||
siteName.value = settings.site_name || 'Sub2API';
|
||||
siteLogo.value = settings.site_logo || '';
|
||||
siteVersion.value = settings.version || '';
|
||||
} catch (error) {
|
||||
console.error('Failed to load public settings:', error);
|
||||
}
|
||||
});
|
||||
// Site settings from appStore (cached, no flicker)
|
||||
const siteName = computed(() => appStore.siteName);
|
||||
const siteLogo = computed(() => appStore.siteLogo);
|
||||
const siteVersion = computed(() => appStore.siteVersion);
|
||||
|
||||
// SVG Icon Components
|
||||
const DashboardIcon = {
|
||||
|
||||
@@ -79,6 +79,7 @@ export default {
|
||||
searchPlaceholder: 'Search...',
|
||||
noOptionsFound: 'No options found',
|
||||
saving: 'Saving...',
|
||||
refresh: 'Refresh',
|
||||
},
|
||||
|
||||
// Navigation
|
||||
@@ -991,6 +992,28 @@ export default {
|
||||
sending: 'Sending...',
|
||||
enterRecipientHint: 'Please enter a recipient email address',
|
||||
},
|
||||
adminApiKey: {
|
||||
title: 'Admin API Key',
|
||||
description: 'Global API key for external system integration with full admin access',
|
||||
notConfigured: 'Admin API key not configured',
|
||||
configured: 'Admin API key is active',
|
||||
currentKey: 'Current Key',
|
||||
regenerate: 'Regenerate',
|
||||
regenerating: 'Regenerating...',
|
||||
delete: 'Delete',
|
||||
deleting: 'Deleting...',
|
||||
create: 'Create Key',
|
||||
creating: 'Creating...',
|
||||
regenerateConfirm: 'Are you sure? The current key will be immediately invalidated.',
|
||||
deleteConfirm: 'Are you sure you want to delete the admin API key? External integrations will stop working.',
|
||||
keyGenerated: 'New admin API key generated',
|
||||
keyDeleted: 'Admin API key deleted',
|
||||
copyKey: 'Copy Key',
|
||||
keyCopied: 'Key copied to clipboard',
|
||||
keyWarning: 'This key will only be shown once. Please copy it now.',
|
||||
securityWarning: 'Warning: This key provides full admin access. Keep it secure.',
|
||||
usage: 'Usage: Add to request header - x-api-key: <your-admin-api-key>',
|
||||
},
|
||||
saveSettings: 'Save Settings',
|
||||
saving: 'Saving...',
|
||||
settingsSaved: 'Settings saved successfully',
|
||||
|
||||
@@ -79,6 +79,7 @@ export default {
|
||||
searchPlaceholder: '搜索...',
|
||||
noOptionsFound: '无匹配选项',
|
||||
saving: '保存中...',
|
||||
refresh: '刷新',
|
||||
},
|
||||
|
||||
// Navigation
|
||||
@@ -1170,6 +1171,28 @@ export default {
|
||||
sending: '发送中...',
|
||||
enterRecipientHint: '请输入收件人邮箱地址',
|
||||
},
|
||||
adminApiKey: {
|
||||
title: '管理员 API Key',
|
||||
description: '用于外部系统集成的全局 API Key,拥有完整的管理员权限',
|
||||
notConfigured: '尚未配置管理员 API Key',
|
||||
configured: '管理员 API Key 已启用',
|
||||
currentKey: '当前密钥',
|
||||
regenerate: '重新生成',
|
||||
regenerating: '生成中...',
|
||||
delete: '删除',
|
||||
deleting: '删除中...',
|
||||
create: '创建密钥',
|
||||
creating: '创建中...',
|
||||
regenerateConfirm: '确定要重新生成吗?当前密钥将立即失效。',
|
||||
deleteConfirm: '确定要删除管理员 API Key 吗?外部集成将停止工作。',
|
||||
keyGenerated: '新的管理员 API Key 已生成',
|
||||
keyDeleted: '管理员 API Key 已删除',
|
||||
copyKey: '复制密钥',
|
||||
keyCopied: '密钥已复制到剪贴板',
|
||||
keyWarning: '此密钥仅显示一次,请立即复制保存。',
|
||||
securityWarning: '警告:此密钥拥有完整的管理员权限,请妥善保管。',
|
||||
usage: '使用方法:在请求头中添加 x-api-key: <your-admin-api-key>',
|
||||
},
|
||||
saveSettings: '保存设置',
|
||||
saving: '保存中...',
|
||||
settingsSaved: '设置保存成功',
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
|
||||
import { defineStore } from 'pinia';
|
||||
import { ref, computed } from 'vue';
|
||||
import type { Toast, ToastType } from '@/types';
|
||||
import type { Toast, ToastType, PublicSettings } from '@/types';
|
||||
import { checkUpdates as checkUpdatesAPI, type VersionInfo, type ReleaseInfo } from '@/api/admin/system';
|
||||
import { getPublicSettings as fetchPublicSettingsAPI } from '@/api/auth';
|
||||
|
||||
export const useAppStore = defineStore('app', () => {
|
||||
// ==================== State ====================
|
||||
@@ -16,6 +17,15 @@ export const useAppStore = defineStore('app', () => {
|
||||
const loading = ref<boolean>(false);
|
||||
const toasts = ref<Toast[]>([]);
|
||||
|
||||
// Public settings cache state
|
||||
const publicSettingsLoaded = ref<boolean>(false);
|
||||
const publicSettingsLoading = ref<boolean>(false);
|
||||
const siteName = ref<string>('Sub2API');
|
||||
const siteLogo = ref<string>('');
|
||||
const siteVersion = ref<string>('');
|
||||
const contactInfo = ref<string>('');
|
||||
const apiBaseUrl = ref<string>('');
|
||||
|
||||
// Version cache state
|
||||
const versionLoaded = ref<boolean>(false);
|
||||
const versionLoading = ref<boolean>(false);
|
||||
@@ -268,6 +278,59 @@ export const useAppStore = defineStore('app', () => {
|
||||
hasUpdate.value = false;
|
||||
}
|
||||
|
||||
// ==================== Public Settings Management ====================
|
||||
|
||||
/**
|
||||
* Fetch public settings (uses cache unless force=true)
|
||||
* @param force - Force refresh from API
|
||||
*/
|
||||
async function fetchPublicSettings(force = false): Promise<PublicSettings | null> {
|
||||
// Return cached data if available and not forcing refresh
|
||||
if (publicSettingsLoaded.value && !force) {
|
||||
return {
|
||||
registration_enabled: false,
|
||||
email_verify_enabled: false,
|
||||
turnstile_enabled: false,
|
||||
turnstile_site_key: '',
|
||||
site_name: siteName.value,
|
||||
site_logo: siteLogo.value,
|
||||
site_subtitle: '',
|
||||
api_base_url: apiBaseUrl.value,
|
||||
contact_info: contactInfo.value,
|
||||
version: siteVersion.value,
|
||||
};
|
||||
}
|
||||
|
||||
// Prevent duplicate requests
|
||||
if (publicSettingsLoading.value) {
|
||||
return null;
|
||||
}
|
||||
|
||||
publicSettingsLoading.value = true;
|
||||
try {
|
||||
const data = await fetchPublicSettingsAPI();
|
||||
siteName.value = data.site_name || 'Sub2API';
|
||||
siteLogo.value = data.site_logo || '';
|
||||
siteVersion.value = data.version || '';
|
||||
contactInfo.value = data.contact_info || '';
|
||||
apiBaseUrl.value = data.api_base_url || '';
|
||||
publicSettingsLoaded.value = true;
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch public settings:', error);
|
||||
return null;
|
||||
} finally {
|
||||
publicSettingsLoading.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear public settings cache
|
||||
*/
|
||||
function clearPublicSettingsCache(): void {
|
||||
publicSettingsLoaded.value = false;
|
||||
}
|
||||
|
||||
// ==================== Return Store API ====================
|
||||
|
||||
return {
|
||||
@@ -277,6 +340,14 @@ export const useAppStore = defineStore('app', () => {
|
||||
loading,
|
||||
toasts,
|
||||
|
||||
// Public settings state
|
||||
publicSettingsLoaded,
|
||||
siteName,
|
||||
siteLogo,
|
||||
siteVersion,
|
||||
contactInfo,
|
||||
apiBaseUrl,
|
||||
|
||||
// Version state
|
||||
versionLoaded,
|
||||
versionLoading,
|
||||
@@ -309,5 +380,9 @@ export const useAppStore = defineStore('app', () => {
|
||||
// Version actions
|
||||
fetchVersion,
|
||||
clearVersionCache,
|
||||
|
||||
// Public settings actions
|
||||
fetchPublicSettings,
|
||||
clearPublicSettingsCache,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -541,7 +541,7 @@ export interface ModelStat {
|
||||
export interface UserUsageTrendPoint {
|
||||
date: string;
|
||||
user_id: number;
|
||||
username: string;
|
||||
email: string;
|
||||
requests: number;
|
||||
tokens: number;
|
||||
cost: number; // 标准计费
|
||||
|
||||
@@ -2,7 +2,20 @@
|
||||
<AppLayout>
|
||||
<div class="space-y-6">
|
||||
<!-- Page Header Actions -->
|
||||
<div class="flex justify-end">
|
||||
<div class="flex justify-end gap-3">
|
||||
<button
|
||||
@click="loadAccounts"
|
||||
:disabled="loading"
|
||||
class="btn btn-secondary"
|
||||
:title="t('common.refresh')"
|
||||
>
|
||||
<svg
|
||||
:class="['w-5 h-5', loading ? 'animate-spin' : '']"
|
||||
fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0l3.181 3.183a8.25 8.25 0 0013.803-3.7M4.031 9.865a8.25 8.25 0 0113.803-3.7l3.181 3.182m0-4.991v4.99" />
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
@click="showCreateModal = true"
|
||||
class="btn btn-primary"
|
||||
|
||||
@@ -180,51 +180,14 @@
|
||||
|
||||
<!-- Charts Grid -->
|
||||
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2">
|
||||
<!-- Model Distribution Chart -->
|
||||
<div class="card p-4">
|
||||
<h3 class="text-sm font-semibold text-gray-900 dark:text-white mb-4">{{ t('admin.dashboard.modelDistribution') }}</h3>
|
||||
<div class="flex items-center gap-6">
|
||||
<div class="w-48 h-48">
|
||||
<Doughnut v-if="modelChartData" :data="modelChartData" :options="doughnutOptions" />
|
||||
<div v-else class="flex items-center justify-center h-full text-gray-500 dark:text-gray-400 text-sm">
|
||||
{{ t('admin.dashboard.noDataAvailable') }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex-1 max-h-48 overflow-y-auto">
|
||||
<table class="w-full text-xs">
|
||||
<thead>
|
||||
<tr class="text-gray-500 dark:text-gray-400">
|
||||
<th class="text-left pb-2">{{ t('admin.dashboard.model') }}</th>
|
||||
<th class="text-right pb-2">{{ t('admin.dashboard.requests') }}</th>
|
||||
<th class="text-right pb-2">{{ t('admin.dashboard.tokens') }}</th>
|
||||
<th class="text-right pb-2">{{ t('admin.dashboard.actual') }}</th>
|
||||
<th class="text-right pb-2">{{ t('admin.dashboard.standard') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="model in modelStats" :key="model.model" class="border-t border-gray-100 dark:border-gray-700">
|
||||
<td class="py-1.5 text-gray-900 dark:text-white font-medium truncate max-w-[100px]" :title="model.model">{{ model.model }}</td>
|
||||
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">{{ formatNumber(model.requests) }}</td>
|
||||
<td class="py-1.5 text-right text-gray-600 dark:text-gray-400">{{ formatTokens(model.total_tokens) }}</td>
|
||||
<td class="py-1.5 text-right text-green-600 dark:text-green-400">${{ formatCost(model.actual_cost) }}</td>
|
||||
<td class="py-1.5 text-right text-gray-400 dark:text-gray-500">${{ formatCost(model.cost) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Token Usage Trend Chart -->
|
||||
<div class="card p-4">
|
||||
<h3 class="text-sm font-semibold text-gray-900 dark:text-white mb-4">{{ t('admin.dashboard.tokenUsageTrend') }}</h3>
|
||||
<div class="h-48">
|
||||
<Line v-if="trendChartData" :data="trendChartData" :options="lineOptions" />
|
||||
<div v-else class="flex items-center justify-center h-full text-gray-500 dark:text-gray-400 text-sm">
|
||||
{{ t('admin.dashboard.noDataAvailable') }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<ModelDistributionChart
|
||||
:model-stats="modelStats"
|
||||
:loading="chartsLoading"
|
||||
/>
|
||||
<TokenUsageTrend
|
||||
:trend-data="trendData"
|
||||
:loading="chartsLoading"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- User Usage Trend (Full Width) -->
|
||||
@@ -244,7 +207,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
|
||||
@@ -255,6 +218,8 @@ import AppLayout from '@/components/layout/AppLayout.vue'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import DateRangePicker from '@/components/common/DateRangePicker.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
|
||||
import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue'
|
||||
|
||||
import {
|
||||
Chart as ChartJS,
|
||||
@@ -262,13 +227,12 @@ import {
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
ArcElement,
|
||||
Title,
|
||||
Tooltip,
|
||||
Legend,
|
||||
Filler
|
||||
} from 'chart.js'
|
||||
import { Line, Doughnut } from 'vue-chartjs'
|
||||
import { Line } from 'vue-chartjs'
|
||||
|
||||
// Register Chart.js components
|
||||
ChartJS.register(
|
||||
@@ -276,7 +240,6 @@ ChartJS.register(
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
ArcElement,
|
||||
Title,
|
||||
Tooltip,
|
||||
Legend,
|
||||
@@ -286,6 +249,7 @@ ChartJS.register(
|
||||
const appStore = useAppStore()
|
||||
const stats = ref<DashboardStats | null>(null)
|
||||
const loading = ref(false)
|
||||
const chartsLoading = ref(false)
|
||||
|
||||
// Chart data
|
||||
const trendData = ref<TrendDataPoint[]>([])
|
||||
@@ -312,34 +276,9 @@ const isDarkMode = computed(() => {
|
||||
const chartColors = computed(() => ({
|
||||
text: isDarkMode.value ? '#e5e7eb' : '#374151',
|
||||
grid: isDarkMode.value ? '#374151' : '#e5e7eb',
|
||||
input: '#3b82f6',
|
||||
output: '#10b981',
|
||||
cache: '#f59e0b',
|
||||
total: '#8b5cf6',
|
||||
}))
|
||||
|
||||
// Doughnut chart options
|
||||
const doughnutOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
legend: {
|
||||
display: false,
|
||||
},
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
label: (context: any) => {
|
||||
const value = context.raw as number
|
||||
const total = context.dataset.data.reduce((a: number, b: number) => a + b, 0)
|
||||
const percentage = ((value / total) * 100).toFixed(1)
|
||||
return `${context.label}: ${formatTokens(value)} (${percentage}%)`
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
// Line chart options
|
||||
// Line chart options (for user trend chart)
|
||||
const lineOptions = computed(() => ({
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
@@ -365,15 +304,6 @@ const lineOptions = computed(() => ({
|
||||
label: (context: any) => {
|
||||
return `${context.dataset.label}: ${formatTokens(context.raw)}`
|
||||
},
|
||||
footer: (tooltipItems: any) => {
|
||||
// Show both costs for the day if we have trend data
|
||||
const dataIndex = tooltipItems[0]?.dataIndex
|
||||
if (dataIndex !== undefined && trendData.value[dataIndex]) {
|
||||
const data = trendData.value[dataIndex]
|
||||
return `Actual: $${formatCost(data.actual_cost)} | Standard: $${formatCost(data.cost)}`
|
||||
}
|
||||
return ''
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -404,71 +334,25 @@ const lineOptions = computed(() => ({
|
||||
},
|
||||
}))
|
||||
|
||||
// Model chart data
|
||||
const modelChartData = computed(() => {
|
||||
if (!modelStats.value?.length) return null
|
||||
|
||||
const colors = [
|
||||
'#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6',
|
||||
'#ec4899', '#14b8a6', '#f97316', '#6366f1', '#84cc16'
|
||||
]
|
||||
|
||||
return {
|
||||
labels: modelStats.value.map(m => m.model),
|
||||
datasets: [{
|
||||
data: modelStats.value.map(m => m.total_tokens),
|
||||
backgroundColor: colors.slice(0, modelStats.value.length),
|
||||
borderWidth: 0,
|
||||
}],
|
||||
}
|
||||
})
|
||||
|
||||
// Trend chart data
|
||||
const trendChartData = computed(() => {
|
||||
if (!trendData.value?.length) return null
|
||||
|
||||
return {
|
||||
labels: trendData.value.map(d => d.date),
|
||||
datasets: [
|
||||
{
|
||||
label: 'Input',
|
||||
data: trendData.value.map(d => d.input_tokens),
|
||||
borderColor: chartColors.value.input,
|
||||
backgroundColor: `${chartColors.value.input}20`,
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
},
|
||||
{
|
||||
label: 'Output',
|
||||
data: trendData.value.map(d => d.output_tokens),
|
||||
borderColor: chartColors.value.output,
|
||||
backgroundColor: `${chartColors.value.output}20`,
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
},
|
||||
{
|
||||
label: 'Cache',
|
||||
data: trendData.value.map(d => d.cache_tokens),
|
||||
borderColor: chartColors.value.cache,
|
||||
backgroundColor: `${chartColors.value.cache}20`,
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
},
|
||||
],
|
||||
}
|
||||
})
|
||||
|
||||
// User trend chart data
|
||||
const userTrendChartData = computed(() => {
|
||||
if (!userTrend.value?.length) return null
|
||||
|
||||
// Extract display name from email (part before @)
|
||||
const getDisplayName = (email: string, userId: number): string => {
|
||||
if (email && email.includes('@')) {
|
||||
return email.split('@')[0]
|
||||
}
|
||||
return `User #${userId}`
|
||||
}
|
||||
|
||||
// Group by user
|
||||
const userGroups = new Map<string, { name: string; data: Map<string, number> }>()
|
||||
const allDates = new Set<string>()
|
||||
|
||||
userTrend.value.forEach(point => {
|
||||
allDates.add(point.date)
|
||||
const key = point.username || `User #${point.user_id}`
|
||||
const key = getDisplayName(point.email, point.user_id)
|
||||
if (!userGroups.has(key)) {
|
||||
userGroups.set(key, { name: key, data: new Map() })
|
||||
}
|
||||
@@ -570,6 +454,7 @@ const loadDashboardStats = async () => {
|
||||
}
|
||||
|
||||
const loadChartData = async () => {
|
||||
chartsLoading.value = true
|
||||
try {
|
||||
const params = {
|
||||
start_date: startDate.value,
|
||||
@@ -588,6 +473,8 @@ const loadChartData = async () => {
|
||||
userTrend.value = userResponse.trend || []
|
||||
} catch (error) {
|
||||
console.error('Error loading chart data:', error)
|
||||
} finally {
|
||||
chartsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,11 +483,6 @@ onMounted(() => {
|
||||
initializeDateRange()
|
||||
loadChartData()
|
||||
})
|
||||
|
||||
// Watch for dark mode changes
|
||||
watch(isDarkMode, () => {
|
||||
// Force chart re-render on theme change
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -2,7 +2,20 @@
|
||||
<AppLayout>
|
||||
<div class="space-y-6">
|
||||
<!-- Page Header Actions -->
|
||||
<div class="flex justify-end">
|
||||
<div class="flex justify-end gap-3">
|
||||
<button
|
||||
@click="loadGroups"
|
||||
:disabled="loading"
|
||||
class="btn btn-secondary"
|
||||
:title="t('common.refresh')"
|
||||
>
|
||||
<svg
|
||||
:class="['w-5 h-5', loading ? 'animate-spin' : '']"
|
||||
fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="1.5"
|
||||
>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0l3.181 3.183a8.25 8.25 0 0013.803-3.7M4.031 9.865a8.25 8.25 0 0113.803-3.7l3.181 3.182m0-4.991v4.99" />
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
@click="showCreateModal = true"
|
||||
class="btn btn-primary"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user