Compare commits

...

7 Commits

Author SHA1 Message Date
shaw
b3463769dc chore: 调整403重试次数跟间隔 2025-12-26 14:19:57 +08:00
shaw
d9e6cfc44d feat: apikey使用弹出适配codex分组 2025-12-26 13:46:40 +08:00
Forest
57fd172287 refactor: 调整 server 目录结构 2025-12-26 10:42:35 +08:00
NepetaLemon
8d7a497553 refactor: 自定义业务错误 (#33)
* refactor: 自定义业务错误

* refactor: 隐藏服务器错误与统一 panic 响应
2025-12-26 08:47:00 +08:00
shaw
b31698b9f2 fix: 修复账户代理ip编辑保存不生效的bug 2025-12-25 21:58:09 +08:00
Forest
eeaff85e47 refactor: 自定义业务错误 2025-12-25 21:06:40 +08:00
Forest
f51ad2e126 refactor: 删除 ports 目录 2025-12-25 17:15:01 +08:00
136 changed files with 2691 additions and 1841 deletions

View File

@@ -19,14 +19,16 @@ linters:
files: files:
- "**/internal/service/**" - "**/internal/service/**"
deny: deny:
- pkg: sub2api/internal/repository - pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "service must not import repository" desc: "service must not import repository"
- pkg: gorm.io/gorm
desc: "service must not import gorm"
handler-no-repository: handler-no-repository:
list-mode: original list-mode: original
files: files:
- "**/internal/handler/**" - "**/internal/handler/**"
deny: deny:
- pkg: sub2api/internal/repository - pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "handler must not import repository" desc: "handler must not import repository"
errcheck: errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`. # Report about not checking of errors in type assertions: `a := b.(MyStruct)`.

View File

@@ -16,14 +16,14 @@ build-embed:
@echo "构建完成: bin/server (with embedded frontend)" @echo "构建完成: bin/server (with embedded frontend)"
test-unit: test-unit:
@go test ./... $(TEST_ARGS) @go test -tags unit ./... -count=1
test-integration: test-integration:
@go test -tags integration ./internal/repository -count=1 -race -parallel=8 @go test -tags integration ./... -count=1 -race -parallel=8
test-cover-integration: test-cover-integration:
@echo "运行集成测试并生成覆盖率报告..." @echo "运行集成测试并生成覆盖率报告..."
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./internal/repository/... @go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./...
@go tool cover -func=coverage.out | tail -1 @go tool cover -func=coverage.out | tail -1
@go tool cover -html=coverage.out -o coverage.html @go tool cover -html=coverage.out -o coverage.html
@echo "覆盖率报告已生成: coverage.html" @echo "覆盖率报告已生成: coverage.html"

View File

@@ -17,7 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/setup" "github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/Wei-Shaw/sub2api/internal/web" "github.com/Wei-Shaw/sub2api/internal/web"
@@ -84,7 +84,7 @@ func main() {
func runSetupServer() { func runSetupServer() {
r := gin.New() r := gin.New()
r.Use(gin.Recovery()) r.Use(middleware.Recovery())
r.Use(middleware.CORS()) r.Use(middleware.CORS())
// Register setup routes // Register setup routes

View File

@@ -4,18 +4,19 @@
package main package main
import ( import (
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"context"
"log"
"net/http"
"time"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
@@ -35,6 +36,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// 业务层 ProviderSets // 业务层 ProviderSets
repository.ProviderSet, repository.ProviderSet,
service.ProviderSet, service.ProviderSet,
middleware.ProviderSet,
handler.ProviderSet, handler.ProviderSet,
// 服务器层 ProviderSet // 服务器层 ProviderSet
@@ -62,7 +64,11 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
func provideCleanup( func provideCleanup(
db *gorm.DB, db *gorm.DB,
rdb *redis.Client, rdb *redis.Client,
services *service.Services, tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
) func() { ) func() {
return func() { return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -74,23 +80,23 @@ func provideCleanup(
fn func() error fn func() error
}{ }{
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
services.TokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil
}}, }},
{"PricingService", func() error { {"PricingService", func() error {
services.Pricing.Stop() pricing.Stop()
return nil return nil
}}, }},
{"EmailQueueService", func() error { {"EmailQueueService", func() error {
services.EmailQueue.Stop() emailQueue.Stop()
return nil return nil
}}, }},
{"OAuthService", func() error { {"OAuthService", func() error {
services.OAuth.Stop() oauth.Stop()
return nil return nil
}}, }},
{"OpenAIOAuthService", func() error { {"OpenAIOAuthService", func() error {
services.OpenAIOAuth.Stop() openaiOAuth.Stop()
return nil return nil
}}, }},
{"Redis", func() error { {"Redis", func() error {

View File

@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
@@ -116,54 +117,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
groupService := service.NewGroupService(groupRepository) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
accountService := service.NewAccountService(accountRepository, groupRepository) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
proxyService := service.NewProxyService(proxyRepository) apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware)
services := &service.Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OpenAIGateway: openAIGatewayService,
OAuth: oAuthService,
OpenAIOAuth: openAIOAuthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
Update: updateService,
TokenRefresh: tokenRefreshService,
}
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
@@ -188,7 +148,11 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
func provideCleanup( func provideCleanup(
db *gorm.DB, db *gorm.DB,
rdb *redis.Client, rdb *redis.Client,
services *service.Services, tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
) func() { ) func() {
return func() { return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -199,23 +163,23 @@ func provideCleanup(
fn func() error fn func() error
}{ }{
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
services.TokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil
}}, }},
{"PricingService", func() error { {"PricingService", func() error {
services.Pricing.Stop() pricing.Stop()
return nil return nil
}}, }},
{"EmailQueueService", func() error { {"EmailQueueService", func() error {
services.EmailQueue.Stop() emailQueue.Stop()
return nil return nil
}}, }},
{"OAuthService", func() error { {"OAuthService", func() error {
services.OAuth.Stop() oauth.Stop()
return nil return nil
}}, }},
{"OpenAIOAuthService", func() error { {"OpenAIOAuthService", func() error {
services.OpenAIOAuth.Stop() openaiOAuth.Stop()
return nil return nil
}}, }},
{"Redis", func() error { {"Redis", func() error {

View File

@@ -117,7 +117,7 @@ func (h *AccountHandler) List(c *gin.Context) {
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list accounts: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -156,7 +156,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
account, err := h.adminService.GetAccount(c.Request.Context(), accountID) account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.NotFound(c, "Account not found") response.ErrorFrom(c, err)
return return
} }
@@ -184,7 +184,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create account: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -218,7 +218,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update account: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -236,7 +236,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteAccount(c.Request.Context(), accountID) err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete account: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -297,7 +297,7 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
SyncProxies: syncProxies, SyncProxies: syncProxies,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Sync failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -332,7 +332,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Use OpenAI OAuth service to refresh token // Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil { if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -349,7 +349,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Use Anthropic/Claude OAuth service to refresh token // Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil { if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -372,7 +372,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
Credentials: newCredentials, Credentials: newCredentials,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -403,7 +403,7 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime) stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get account stats: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -421,7 +421,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID) account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to clear error: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -570,7 +570,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
Extra: req.Extra, Extra: req.Extra,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to bulk update accounts: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -595,7 +595,7 @@ func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID) result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -613,7 +613,7 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID) result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate setup token URL: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -642,7 +642,7 @@ func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -664,7 +664,7 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -692,7 +692,7 @@ func (h *OAuthHandler) CookieAuth(c *gin.Context) {
Scope: "full", Scope: "full",
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -714,7 +714,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
Scope: "inference", Scope: "inference",
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -732,7 +732,7 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID) usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -750,7 +750,7 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID) err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to clear rate limit: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -768,7 +768,7 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID) stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get today stats: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -797,7 +797,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable) account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
if err != nil { if err != nil {
response.InternalError(c, "Failed to update schedulable status: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -65,7 +65,7 @@ func (h *GroupHandler) List(c *gin.Context) {
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive) groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list groups: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -87,7 +87,7 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
} }
if err != nil { if err != nil {
response.InternalError(c, "Failed to get groups: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -105,7 +105,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
group, err := h.adminService.GetGroup(c.Request.Context(), groupID) group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
if err != nil { if err != nil {
response.NotFound(c, "Group not found") response.ErrorFrom(c, err)
return return
} }
@@ -133,7 +133,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create group: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -168,7 +168,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update group: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -186,7 +186,7 @@ func (h *GroupHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteGroup(c.Request.Context(), groupID) err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete group: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -225,7 +225,7 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize) keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get group API keys: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -40,7 +40,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -71,7 +71,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -103,7 +103,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to refresh token: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -122,7 +122,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
// Get account // Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID) account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.NotFound(c, "Account not found") response.ErrorFrom(c, err)
return return
} }
@@ -141,7 +141,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
// Use OpenAI OAuth service to refresh token // Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil { if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -159,7 +159,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
Credentials: newCredentials, Credentials: newCredentials,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -192,7 +192,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -220,7 +220,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to create account: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -53,7 +53,7 @@ func (h *ProxyHandler) List(c *gin.Context) {
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list proxies: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -69,7 +69,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
if withCount { if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context()) proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
response.Success(c, proxies) response.Success(c, proxies)
@@ -78,7 +78,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
proxies, err := h.adminService.GetAllProxies(c.Request.Context()) proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -96,7 +96,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID) proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.NotFound(c, "Proxy not found") response.ErrorFrom(c, err)
return return
} }
@@ -121,7 +121,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
Password: strings.TrimSpace(req.Password), Password: strings.TrimSpace(req.Password),
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -153,7 +153,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
Status: strings.TrimSpace(req.Status), Status: strings.TrimSpace(req.Status),
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -171,7 +171,7 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID) err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete proxy: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -189,7 +189,7 @@ func (h *ProxyHandler) Test(c *gin.Context) {
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID) result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to test proxy: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -229,7 +229,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize) accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get proxy accounts: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -272,7 +272,7 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Check for duplicates (same host, port, username, password) // Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password) exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil { if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -43,7 +43,7 @@ func (h *RedeemHandler) List(c *gin.Context) {
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list redeem codes: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -61,7 +61,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID) code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
if err != nil { if err != nil {
response.NotFound(c, "Redeem code not found") response.ErrorFrom(c, err)
return return
} }
@@ -85,7 +85,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
ValidityDays: req.ValidityDays, ValidityDays: req.ValidityDays,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate redeem codes: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -103,7 +103,7 @@ func (h *RedeemHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID) err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete redeem code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -123,7 +123,7 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs) deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
if err != nil { if err != nil {
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -144,7 +144,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID) code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to expire redeem code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -178,7 +178,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
// Get all codes without pagination (use large page size) // Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "") codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
if err != nil { if err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -27,7 +27,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
func (h *SettingHandler) GetSettings(c *gin.Context) { func (h *SettingHandler) GetSettings(c *gin.Context) {
settings, err := h.settingService.GetAllSettings(c.Request.Context()) settings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -111,14 +111,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
response.InternalError(c, "Failed to update settings: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
// 重新获取设置返回 // 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get updated settings: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -166,7 +166,7 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
err := h.emailService.TestSmtpConnectionWithConfig(config) err := h.emailService.TestSmtpConnectionWithConfig(config)
if err != nil { if err != nil {
response.BadRequest(c, "SMTP connection test failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -252,7 +252,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
` `
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil { if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.BadRequest(c, "Failed to send test email: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -264,7 +264,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) { func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context()) maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get admin API key status: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -279,7 +279,7 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) { func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context()) key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate admin API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -292,7 +292,7 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
// DELETE /api/v1/admin/settings/admin-api-key // DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) { func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil { if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
response.InternalError(c, "Failed to delete admin API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -78,7 +78,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status) subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -96,7 +96,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID) subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil { if err != nil {
response.NotFound(c, "Subscription not found") response.ErrorFrom(c, err)
return return
} }
@@ -141,7 +141,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
Notes: req.Notes, Notes: req.Notes,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to assign subscription: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -168,7 +168,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
Notes: req.Notes, Notes: req.Notes,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -192,7 +192,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil { if err != nil {
response.InternalError(c, "Failed to extend subscription: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -210,7 +210,7 @@ func (h *SubscriptionHandler) Revoke(c *gin.Context) {
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID) err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to revoke subscription: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -230,7 +230,7 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize) subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list group subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -248,7 +248,7 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID) subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list user subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -90,7 +90,7 @@ func (h *UsageHandler) List(c *gin.Context) {
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -158,7 +158,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 { if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime) stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
response.Success(c, stats) response.Success(c, stats)
@@ -168,7 +168,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if userID > 0 { if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime) stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
response.Success(c, stats) response.Success(c, stats)
@@ -178,7 +178,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
// Get global stats // Get global stats
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime) stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -197,7 +197,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
// Limit to 30 results // Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword) users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
if err != nil { if err != nil {
response.InternalError(c, "Failed to search users: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -236,7 +236,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30) keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil { if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -64,7 +64,7 @@ func (h *UserHandler) List(c *gin.Context) {
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search) users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list users: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -82,7 +82,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
user, err := h.adminService.GetUser(c.Request.Context(), userID) user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil { if err != nil {
response.NotFound(c, "User not found") response.ErrorFrom(c, err)
return return
} }
@@ -109,7 +109,7 @@ func (h *UserHandler) Create(c *gin.Context) {
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create user: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -144,7 +144,7 @@ func (h *UserHandler) Update(c *gin.Context) {
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update user: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -162,7 +162,7 @@ func (h *UserHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteUser(c.Request.Context(), userID) err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete user: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -186,7 +186,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil { if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -206,7 +206,7 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize) keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get user API keys: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -226,7 +226,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period) stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get user usage: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -57,7 +57,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list API keys: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -87,7 +87,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID) key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
if err != nil { if err != nil {
response.NotFound(c, "API key not found") response.ErrorFrom(c, err)
return return
} }
@@ -128,7 +128,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
} }
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq) key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
if err != nil { if err != nil {
response.InternalError(c, "Failed to create API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -173,7 +173,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq) key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
if err != nil { if err != nil {
response.InternalError(c, "Failed to update API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -203,7 +203,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID) err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -227,7 +227,7 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID) groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get available groups: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -66,14 +66,14 @@ func (h *AuthHandler) Register(c *gin.Context) {
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过) // Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" { if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
} }
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode) token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
if err != nil { if err != nil {
response.BadRequest(c, "Registration failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -95,13 +95,13 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
// Turnstile 验证 // Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email) result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to send verification code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -122,13 +122,13 @@ func (h *AuthHandler) Login(c *gin.Context) {
// Turnstile 验证 // Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password) token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil { if err != nil {
response.Unauthorized(c, "Login failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -10,10 +10,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -41,13 +41,13 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
// POST /v1/messages // POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) { func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置 // 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c) apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return return
} }
user, ok := middleware.GetUserFromContext(c) user, ok := middleware2.GetUserFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
@@ -79,7 +79,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
streamStarted := false streamStarted := false
// 获取订阅信息可能为nil- 提前获取用于后续检查 // 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满 // 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(user.Concurrency)
@@ -171,7 +171,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// GET /v1/models // GET /v1/models
// Returns different model lists based on the API key's group platform // Returns different model lists based on the API key's group platform
func (h *GatewayHandler) Models(c *gin.Context) { func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware.GetApiKeyFromContext(c) apiKey, _ := middleware2.GetApiKeyFromContext(c)
// Return OpenAI models for OpenAI platform groups // Return OpenAI models for OpenAI platform groups
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" { if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
@@ -192,13 +192,13 @@ func (h *GatewayHandler) Models(c *gin.Context) {
// Usage handles getting account balance for CC Switch integration // Usage handles getting account balance for CC Switch integration
// GET /v1/usage // GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) { func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c) apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return return
} }
user, ok := middleware.GetUserFromContext(c) user, ok := middleware2.GetUserFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return return
@@ -206,7 +206,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 订阅模式:返回订阅限额信息 // 订阅模式:返回订阅限额信息
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware.GetSubscriptionFromContext(c) subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription") h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
return return
@@ -328,13 +328,13 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
// 特点:校验订阅/余额,但不计算并发、不记录使用量 // 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) { func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置 // 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c) apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return return
} }
user, ok := middleware.GetUserFromContext(c) user, ok := middleware2.GetUserFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
@@ -362,7 +362,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
// 获取订阅信息可能为nil // 获取订阅信息可能为nil
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 校验 billing eligibility订阅/余额) // 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额 // 【注意】不计算并发,但需要校验订阅/余额

View File

@@ -9,8 +9,8 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -40,13 +40,13 @@ func NewOpenAIGatewayHandler(
// POST /openai/v1/responses // POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get apiKey and user from context (set by ApiKeyAuth middleware) // Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware.GetApiKeyFromContext(c) apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return return
} }
user, ok := middleware.GetUserFromContext(c) user, ok := middleware2.GetUserFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
@@ -91,7 +91,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
streamStarted := false streamStarted := false
// Get subscription info (may be nil) // Get subscription info (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full // 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(user.Concurrency)

View File

@@ -57,7 +57,7 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code) result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to redeem code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -84,7 +84,7 @@ func (h *RedeemHandler) GetHistory(c *gin.Context) {
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit) codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get history: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -26,7 +26,7 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
func (h *SettingHandler) GetPublicSettings(c *gin.Context) { func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
settings, err := h.settingService.GetPublicSettings(c.Request.Context()) settings, err := h.settingService.GetPublicSettings(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -58,7 +58,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -82,7 +82,7 @@ func (h *SubscriptionHandler) GetActive(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get active subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -107,7 +107,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// Get all active subscriptions with progress // Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -146,7 +146,7 @@ func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
// Get all active subscriptions // Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -55,7 +55,7 @@ func (h *UsageHandler) List(c *gin.Context) {
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation // [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id) apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil { if err != nil {
response.NotFound(c, "API key not found") response.ErrorFrom(c, err)
return return
} }
if apiKey.UserID != user.ID { if apiKey.UserID != user.ID {
@@ -77,7 +77,7 @@ func (h *UsageHandler) List(c *gin.Context) {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params) records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
} }
if err != nil { if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -107,7 +107,7 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
record, err := h.usageService.GetByID(c.Request.Context(), usageID) record, err := h.usageService.GetByID(c.Request.Context(), usageID)
if err != nil { if err != nil {
response.NotFound(c, "Usage record not found") response.ErrorFrom(c, err)
return return
} }
@@ -204,7 +204,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime) stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
} }
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -259,7 +259,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID) stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get dashboard statistics") response.ErrorFrom(c, err)
return return
} }
@@ -286,7 +286,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage trend") response.ErrorFrom(c, err)
return return
} }
@@ -317,7 +317,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get model statistics") response.ErrorFrom(c, err)
return return
} }
@@ -362,7 +362,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
// Verify ownership of all requested API keys // Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000}) userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil { if err != nil {
response.InternalError(c, "Failed to verify API key ownership") response.ErrorFrom(c, err)
return return
} }
@@ -386,7 +386,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs) stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get API key usage stats") response.ErrorFrom(c, err)
return return
} }

View File

@@ -49,7 +49,7 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
userData, err := h.userService.GetByID(c.Request.Context(), user.ID) userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get user profile: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -86,7 +86,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
} }
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq) err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to change password: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -120,7 +120,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to update profile: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -0,0 +1,158 @@
package errors
import (
"errors"
"fmt"
"net/http"
)
const (
UnknownCode = http.StatusInternalServerError
UnknownReason = ""
UnknownMessage = "internal error"
)
type Status struct {
Code int32 `json:"code"`
Reason string `json:"reason,omitempty"`
Message string `json:"message"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ApplicationError is the standard error type used to control HTTP responses.
//
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
type ApplicationError struct {
Status
cause error
}
// Error is kept for backwards compatibility within this package.
type Error = ApplicationError
func (e *ApplicationError) Error() string {
if e == nil {
return "<nil>"
}
if e.cause == nil {
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
}
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
}
// Unwrap provides compatibility for Go 1.13 error chains.
func (e *ApplicationError) Unwrap() error { return e.cause }
// Is matches each error in the chain with the target value.
func (e *ApplicationError) Is(err error) bool {
if se := new(ApplicationError); errors.As(err, &se) {
return se.Code == e.Code && se.Reason == e.Reason
}
return false
}
// WithCause attaches the underlying cause of the error.
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
err := Clone(e)
err.cause = cause
return err
}
// WithMetadata deep-copies the given metadata map.
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
err := Clone(e)
if md == nil {
err.Metadata = nil
return err
}
err.Metadata = make(map[string]string, len(md))
for k, v := range md {
err.Metadata[k] = v
}
return err
}
// New returns an error object for the code, message.
func New(code int, reason, message string) *ApplicationError {
return &ApplicationError{
Status: Status{
Code: int32(code),
Message: message,
Reason: reason,
},
}
}
// Newf New(code fmt.Sprintf(format, a...))
func Newf(code int, reason, format string, a ...any) *ApplicationError {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Errorf returns an error object for the code, message and error info.
func Errorf(code int, reason, format string, a ...any) error {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Code returns the http code for an error.
// It supports wrapped errors.
func Code(err error) int {
if err == nil {
return http.StatusOK
}
return int(FromError(err).Code)
}
// Reason returns the reason for a particular error.
// It supports wrapped errors.
func Reason(err error) string {
if err == nil {
return UnknownReason
}
return FromError(err).Reason
}
// Message returns the message for a particular error.
// It supports wrapped errors.
func Message(err error) string {
if err == nil {
return ""
}
return FromError(err).Message
}
// Clone deep clone error to a new error.
func Clone(err *ApplicationError) *ApplicationError {
if err == nil {
return nil
}
var metadata map[string]string
if err.Metadata != nil {
metadata = make(map[string]string, len(err.Metadata))
for k, v := range err.Metadata {
metadata[k] = v
}
}
return &ApplicationError{
cause: err.cause,
Status: Status{
Code: err.Code,
Reason: err.Reason,
Message: err.Message,
Metadata: metadata,
},
}
}
// FromError tries to convert an error to *ApplicationError.
// It supports wrapped errors.
func FromError(err error) *ApplicationError {
if err == nil {
return nil
}
if se := new(ApplicationError); errors.As(err, &se) {
return se
}
// Fall back to a generic internal error.
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
}

View File

@@ -0,0 +1,168 @@
//go:build unit
package errors
import (
stderrors "errors"
"fmt"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestApplicationError_Basics(t *testing.T) {
tests := []struct {
name string
err *ApplicationError
want Status
wantIs bool
target error
wrapped error
}{
{
name: "new",
err: New(400, "BAD_REQUEST", "invalid input"),
want: Status{
Code: 400,
Reason: "BAD_REQUEST",
Message: "invalid input",
},
},
{
name: "is_matches_code_and_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "UNAUTHORIZED", "ignored message"),
wantIs: true,
},
{
name: "is_does_not_match_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "DIFFERENT", "ignored message"),
wantIs: false,
},
{
name: "from_error_unwraps_wrapped_application_error",
err: New(404, "NOT_FOUND", "missing"),
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
want: Status{
Code: 404,
Reason: "NOT_FOUND",
Message: "missing",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err != nil {
require.Equal(t, tt.want, tt.err.Status)
}
if tt.target != nil {
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
}
if tt.wrapped != nil {
got := FromError(tt.wrapped)
require.Equal(t, tt.want, got.Status)
}
})
}
}
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
tests := []struct {
name string
md map[string]string
}{
{name: "non_nil", md: map[string]string{"a": "1"}},
{name: "nil", md: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
if tt.md == nil {
require.Nil(t, appErr.Metadata)
return
}
tt.md["a"] = "changed"
require.Equal(t, "1", appErr.Metadata["a"])
})
}
}
func TestFromError_Generic(t *testing.T) {
tests := []struct {
name string
err error
wantCode int32
wantReason string
wantMsg string
}{
{
name: "plain_error",
err: stderrors.New("boom"),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
{
name: "wrapped_plain_error",
err: fmt.Errorf("wrap: %w", io.EOF),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromError(tt.err)
require.Equal(t, tt.wantCode, got.Code)
require.Equal(t, tt.wantReason, got.Reason)
require.Equal(t, tt.wantMsg, got.Message)
require.Equal(t, tt.err, got.Unwrap())
})
}
}
func TestToHTTP(t *testing.T) {
tests := []struct {
name string
err error
wantStatusCode int
wantBody Status
}{
{
name: "nil_error",
err: nil,
wantStatusCode: http.StatusOK,
wantBody: Status{Code: int32(http.StatusOK)},
},
{
name: "application_error",
err: Forbidden("FORBIDDEN", "no access"),
wantStatusCode: http.StatusForbidden,
wantBody: Status{
Code: int32(http.StatusForbidden),
Reason: "FORBIDDEN",
Message: "no access",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, body := ToHTTP(tt.err)
require.Equal(t, tt.wantStatusCode, code)
require.Equal(t, tt.wantBody, body)
})
}
}

View File

@@ -0,0 +1,21 @@
package errors
import "net/http"
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
//
// The returned body matches the project's Status shape:
// { code, reason, message, metadata }.
func ToHTTP(err error) (statusCode int, body Status) {
if err == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
appErr := FromError(err)
if appErr == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
}

View File

@@ -0,0 +1,114 @@
// nolint:mnd
package errors
import "net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *ApplicationError {
return New(http.StatusBadRequest, reason, message)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func IsBadRequest(err error) bool {
return Code(err) == http.StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func TooManyRequests(reason, message string) *ApplicationError {
return New(http.StatusTooManyRequests, reason, message)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func IsTooManyRequests(err error) bool {
return Code(err) == http.StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func Unauthorized(reason, message string) *ApplicationError {
return New(http.StatusUnauthorized, reason, message)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func IsUnauthorized(err error) bool {
return Code(err) == http.StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func Forbidden(reason, message string) *ApplicationError {
return New(http.StatusForbidden, reason, message)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func IsForbidden(err error) bool {
return Code(err) == http.StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func NotFound(reason, message string) *ApplicationError {
return New(http.StatusNotFound, reason, message)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func IsNotFound(err error) bool {
return Code(err) == http.StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func Conflict(reason, message string) *ApplicationError {
return New(http.StatusConflict, reason, message)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func IsConflict(err error) bool {
return Code(err) == http.StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func InternalServer(reason, message string) *ApplicationError {
return New(http.StatusInternalServerError, reason, message)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func IsInternalServer(err error) bool {
return Code(err) == http.StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func ServiceUnavailable(reason, message string) *ApplicationError {
return New(http.StatusServiceUnavailable, reason, message)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func IsServiceUnavailable(err error) bool {
return Code(err) == http.StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func GatewayTimeout(reason, message string) *ApplicationError {
return New(http.StatusGatewayTimeout, reason, message)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func IsGatewayTimeout(err error) bool {
return Code(err) == http.StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func ClientClosed(reason, message string) *ApplicationError {
return New(499, reason, message)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func IsClientClosed(err error) bool {
return Code(err) == 499
}

View File

@@ -4,14 +4,17 @@ import (
"math" "math"
"net/http" "net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// Response 标准API响应格式 // Response 标准API响应格式
type Response struct { type Response struct {
Code int `json:"code"` Code int `json:"code"`
Message string `json:"message"` Message string `json:"message"`
Data any `json:"data,omitempty"` Reason string `json:"reason,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Data any `json:"data,omitempty"`
} }
// PaginatedData 分页数据格式(匹配前端期望) // PaginatedData 分页数据格式(匹配前端期望)
@@ -44,11 +47,36 @@ func Created(c *gin.Context, data any) {
// Error 返回错误响应 // Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) { func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{ c.JSON(statusCode, Response{
Code: statusCode, Code: statusCode,
Message: message, Message: message,
Reason: "",
Metadata: nil,
}) })
} }
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: reason,
Metadata: metadata,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func ErrorFrom(c *gin.Context, err error) bool {
if err == nil {
return false
}
statusCode, status := infraerrors.ToHTTP(err)
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
// BadRequest 返回400错误 // BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) { func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message) Error(c, http.StatusBadRequest, message)

View File

@@ -0,0 +1,171 @@
//go:build unit
package response
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
reason string
metadata map[string]string
want Response
}{
{
name: "plain_error",
statusCode: http.StatusBadRequest,
message: "invalid request",
want: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
},
},
{
name: "structured_error",
statusCode: http.StatusForbidden,
message: "no access",
reason: "FORBIDDEN",
metadata: map[string]string{"k": "v"},
want: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"k": "v"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
require.Equal(t, tt.statusCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.want, got)
})
}
}
func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantWritten bool
wantHTTPCode int
wantBody Response
}{
{
name: "nil_error",
err: nil,
wantWritten: false,
},
{
name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"scope": "admin"},
},
},
{
name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
Reason: "INVALID_REQUEST",
},
},
{
name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
Code: http.StatusUnauthorized,
Message: "unauthorized",
Reason: "UNAUTHORIZED",
},
},
{
name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
Code: http.StatusNotFound,
Message: "not found",
Reason: "NOT_FOUND",
},
},
{
name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
Code: http.StatusConflict,
Message: "conflict",
Reason: "CONFLICT",
},
},
{
name: "unknown_error_defaults_to_500",
err: errors.New("boom"),
wantWritten: true,
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
if !tt.wantWritten {
require.Equal(t, 200, w.Code)
require.Empty(t, w.Body.String())
return
}
require.Equal(t, tt.wantHTTPCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -3,32 +3,33 @@ package repository
import ( import (
"context" "context"
"errors" "errors"
"time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
type AccountRepository struct { type accountRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewAccountRepository(db *gorm.DB) *AccountRepository { func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &AccountRepository{db: db} return &accountRepository{db: db}
} }
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error { func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error return r.db.WithContext(ctx).Create(account).Error
} }
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
} }
// 填充 GroupIDs 和 Groups 虚拟字段 // 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups)) account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
@@ -42,7 +43,7 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou
return &account, nil return &account, nil
} }
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) { func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
if crsAccountID == "" { if crsAccountID == "" {
return nil, nil return nil, nil
} }
@@ -58,11 +59,11 @@ func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return &account, nil return &account, nil
} }
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error { func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error return r.db.WithContext(ctx).Save(account).Error
} }
func (r *AccountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系 // 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
@@ -71,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
} }
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "")
} }
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query // ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account var accounts []model.Account
var total int64 var total int64
@@ -130,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
}, nil }, nil
} }
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
@@ -141,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m
return accounts, err return accounts, err
} }
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) { func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", model.StatusActive).
@@ -151,12 +152,12 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er
return accounts, err return accounts, err
} }
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error { func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
} }
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"status": model.StatusError, "status": model.StatusError,
@@ -164,7 +165,7 @@ func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg str
}).Error }).Error
} }
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{ ag := &model.AccountGroup{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
@@ -173,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
return r.db.WithContext(ctx).Create(ag).Error return r.db.WithContext(ctx).Create(ag).Error
} }
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID). return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error Delete(&model.AccountGroup{}).Error
} }
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) { func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id"). Joins("JOIN account_groups ON account_groups.group_id = groups.id").
@@ -187,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m
return groups, err return groups, err
} }
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive). Where("platform = ? AND status = ?", platform, model.StatusActive).
@@ -197,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string)
return accounts, err return accounts, err
} }
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定 // 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil { if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
@@ -220,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro
} }
// ListSchedulable 获取所有可调度的账号 // ListSchedulable 获取所有可调度的账号
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) { func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
@@ -234,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun
} }
// ListSchedulableByGroupID 按组获取可调度的账号 // ListSchedulableByGroupID 按组获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
@@ -250,7 +251,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
} }
// ListSchedulableByPlatform 按平台获取可调度的账号 // ListSchedulableByPlatform 按平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
@@ -265,7 +266,7 @@ func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platf
} }
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号 // ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
@@ -282,7 +283,7 @@ func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
} }
// SetRateLimited 标记账号为限流状态(429) // SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -292,13 +293,13 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
} }
// SetOverloaded 标记账号为过载状态(529) // SetOverloaded 标记账号为过载状态(529)
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("overload_until", until).Error Update("overload_until", until).Error
} }
// ClearRateLimit 清除账号的限流状态 // ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error { func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"rate_limited_at": nil, "rate_limited_at": nil,
@@ -308,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
} }
// UpdateSessionWindow 更新账号的5小时时间窗口信息 // UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{ updates := map[string]any{
"session_window_status": status, "session_window_status": status,
} }
@@ -322,14 +323,14 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
} }
// SetSchedulable 设置账号的调度开关 // SetSchedulable 设置账号的调度开关
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error Update("schedulable", schedulable).Error
} }
// UpdateExtra updates specific fields in account's Extra JSONB field // UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields // It merges the updates into existing Extra data without overwriting other fields
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 { if len(updates) == 0 {
return nil return nil
} }
@@ -357,7 +358,7 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m
// BulkUpdate updates multiple accounts with the provided fields. // BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them. // It merges credentials/extra JSONB fields instead of overwriting them.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) { func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 { if len(ids) == 0 {
return 0, nil return 0, nil
} }

View File

@@ -9,7 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -18,13 +18,13 @@ type AccountRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *AccountRepository repo *accountRepository
} }
func (s *AccountRepoSuite) SetupTest() { func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db) s.repo = NewAccountRepository(s.db).(*accountRepository)
} }
func TestAccountRepoSuite(t *testing.T) { func TestAccountRepoSuite(t *testing.T) {
@@ -167,7 +167,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源 // 每个 case 重新获取隔离资源
db := testTx(s.T()) db := testTx(s.T())
repo := NewAccountRepository(db) repo := NewAccountRepository(db).(*accountRepository)
ctx := context.Background() ctx := context.Background()
tt.setup(db) tt.setup(db)
@@ -513,7 +513,7 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1}) a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
newPriority := 99 newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, ports.AccountBulkUpdate{ affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
Priority: &newPriority, Priority: &newPriority,
}) })
s.Require().NoError(err) s.Require().NoError(err)
@@ -531,7 +531,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
Credentials: model.JSONB{"existing": "value"}, Credentials: model.JSONB{"existing": "value"},
}) })
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{ _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: model.JSONB{"new_key": "new_value"}, Credentials: model.JSONB{"new_key": "new_value"},
}) })
s.Require().NoError(err) s.Require().NoError(err)
@@ -547,7 +547,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
Extra: model.JSONB{"existing": "val"}, Extra: model.JSONB{"existing": "val"},
}) })
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{ _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: model.JSONB{"new_key": "new_val"}, Extra: model.JSONB{"new_key": "new_val"},
}) })
s.Require().NoError(err) s.Require().NoError(err)
@@ -558,7 +558,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() { func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, ports.AccountBulkUpdate{}) affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, service.AccountBulkUpdate{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Zero(affected) s.Require().Zero(affected)
} }
@@ -566,7 +566,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() { func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"}) a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{}) affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Zero(affected) s.Require().Zero(affected)
} }

View File

@@ -5,8 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -19,7 +18,7 @@ type apiKeyCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache { func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
return &apiKeyCache{rdb: rdb} return &apiKeyCache{rdb: rdb}
} }

View File

@@ -2,51 +2,55 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ApiKeyRepository struct { type apiKeyRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository { func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &ApiKeyRepository{db: db} return &apiKeyRepository{db: db}
} }
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Create(key).Error err := r.db.WithContext(ctx).Create(key).Error
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
} }
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &key, nil return &key, nil
} }
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &apiKey, nil return &apiKey, nil
} }
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
} }
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
} }
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
@@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
}, nil }, nil
} }
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
return count, err return count, err
} }
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
@@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
} }
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey var keys []model.ApiKey
db := r.db.WithContext(ctx).Model(&model.ApiKey{}) db := r.db.WithContext(ctx).Model(&model.ApiKey{})
@@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
} }
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}). result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Update("group_id", nil) Update("group_id", nil)
@@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
} }
// CountByGroupID 获取分组的 API Key 数量 // CountByGroupID 获取分组的 API Key 数量
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err

View File

@@ -16,13 +16,13 @@ type ApiKeyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *ApiKeyRepository repo *apiKeyRepository
} }
func (s *ApiKeyRepoSuite) SetupTest() { func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db) s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
} }
func TestApiKeyRepoSuite(t *testing.T) { func TestApiKeyRepoSuite(t *testing.T) {

View File

@@ -8,8 +8,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -58,7 +57,7 @@ type billingCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewBillingCache(rdb *redis.Client) ports.BillingCache { func NewBillingCache(rdb *redis.Client) service.BillingCache {
return &billingCache{rdb: rdb} return &billingCache{rdb: rdb}
} }
@@ -90,7 +89,7 @@ func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) { func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result() result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil { if err != nil {
@@ -102,8 +101,8 @@ func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID
return c.parseSubscriptionCache(result) return c.parseSubscriptionCache(result)
} }
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) { func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
result := &ports.SubscriptionCacheData{} result := &service.SubscriptionCacheData{}
result.Status = data[subFieldStatus] result.Status = data[subFieldStatus]
if result.Status == "" { if result.Status == "" {
@@ -136,7 +135,7 @@ func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.Su
return result, nil return result, nil
} }
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error { func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
if data == nil { if data == nil {
return nil return nil
} }

View File

@@ -8,7 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -21,18 +21,18 @@ type BillingCacheSuite struct {
func (s *BillingCacheSuite) TestUserBalance() { func (s *BillingCacheSuite) TestUserBalance() {
tests := []struct { tests := []struct {
name string name string
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{ }{
{ {
name: "missing_key_returns_redis_nil", name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
_, err := cache.GetUserBalance(ctx, 1) _, err := cache.GetUserBalance(ctx, 1)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key") require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
}, },
}, },
{ {
name: "deduct_on_nonexistent_is_noop", name: "deduct_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(1) userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
@@ -44,7 +44,7 @@ func (s *BillingCacheSuite) TestUserBalance() {
}, },
{ {
name: "set_and_get_with_ttl", name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(2) userID := int64(2)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
@@ -61,7 +61,7 @@ func (s *BillingCacheSuite) TestUserBalance() {
}, },
{ {
name: "deduct_reduces_balance", name: "deduct_reduces_balance",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(3) userID := int64(3)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance") require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
@@ -74,7 +74,7 @@ func (s *BillingCacheSuite) TestUserBalance() {
}, },
{ {
name: "invalidate_removes_key", name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(100) userID := int64(100)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
@@ -96,7 +96,7 @@ func (s *BillingCacheSuite) TestUserBalance() {
}, },
{ {
name: "deduct_refreshes_ttl", name: "deduct_refreshes_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(103) userID := int64(103)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
@@ -133,11 +133,11 @@ func (s *BillingCacheSuite) TestUserBalance() {
func (s *BillingCacheSuite) TestSubscriptionCache() { func (s *BillingCacheSuite) TestSubscriptionCache() {
tests := []struct { tests := []struct {
name string name string
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{ }{
{ {
name: "missing_key_returns_redis_nil", name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(10) userID := int64(10)
groupID := int64(20) groupID := int64(20)
@@ -147,7 +147,7 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}, },
{ {
name: "update_usage_on_nonexistent_is_noop", name: "update_usage_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(11) userID := int64(11)
groupID := int64(21) groupID := int64(21)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
@@ -161,12 +161,12 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}, },
{ {
name: "set_and_get_with_ttl", name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(12) userID := int64(12)
groupID := int64(22) groupID := int64(22)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &ports.SubscriptionCacheData{ data := &service.SubscriptionCacheData{
Status: "active", Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0, DailyUsage: 1.0,
@@ -189,11 +189,11 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}, },
{ {
name: "update_usage_increments_all_fields", name: "update_usage_increments_all_fields",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(13) userID := int64(13)
groupID := int64(23) groupID := int64(23)
data := &ports.SubscriptionCacheData{ data := &service.SubscriptionCacheData{
Status: "active", Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0, DailyUsage: 1.0,
@@ -214,12 +214,12 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}, },
{ {
name: "invalidate_removes_key", name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(101) userID := int64(101)
groupID := int64(10) groupID := int64(10)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &ports.SubscriptionCacheData{ data := &service.SubscriptionCacheData{
Status: "active", Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0, DailyUsage: 1.0,
@@ -245,7 +245,7 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}, },
{ {
name: "missing_status_returns_parsing_error", name: "missing_status_returns_parsing_error",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(102) userID := int64(102)
groupID := int64(11) groupID := int64(11)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)

View File

@@ -5,8 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -107,7 +106,7 @@ type concurrencyCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache { func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
return &concurrencyCache{rdb: rdb} return &concurrencyCache{rdb: rdb}
} }

View File

@@ -8,7 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -16,7 +16,7 @@ import (
type ConcurrencyCacheSuite struct { type ConcurrencyCacheSuite struct {
IntegrationRedisSuite IntegrationRedisSuite
cache ports.ConcurrencyCache cache service.ConcurrencyCache
} }
func (s *ConcurrencyCacheSuite) SetupTest() { func (s *ConcurrencyCacheSuite) SetupTest() {

View File

@@ -5,8 +5,7 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -16,24 +15,24 @@ type emailCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewEmailCache(rdb *redis.Client) ports.EmailCache { func NewEmailCache(rdb *redis.Client) service.EmailCache {
return &emailCache{rdb: rdb} return &emailCache{rdb: rdb}
} }
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) { func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email key := verifyCodeKeyPrefix + email
val, err := c.rdb.Get(ctx, key).Result() val, err := c.rdb.Get(ctx, key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var data ports.VerificationCodeData var data service.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil { if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err return nil, err
} }
return &data, nil return &data, nil
} }
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error { func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email key := verifyCodeKeyPrefix + email
val, err := json.Marshal(data) val, err := json.Marshal(data)
if err != nil { if err != nil {

View File

@@ -7,7 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -15,7 +15,7 @@ import (
type EmailCacheSuite struct { type EmailCacheSuite struct {
IntegrationRedisSuite IntegrationRedisSuite
cache ports.EmailCache cache service.EmailCache
} }
func (s *EmailCacheSuite) SetupTest() { func (s *EmailCacheSuite) SetupTest() {
@@ -31,7 +31,7 @@ func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() { func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
email := "a@example.com" email := "a@example.com"
emailTTL := 2 * time.Minute emailTTL := 2 * time.Minute
data := &ports.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()} data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode") require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
@@ -44,7 +44,7 @@ func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
func (s *EmailCacheSuite) TestVerificationCode_TTL() { func (s *EmailCacheSuite) TestVerificationCode_TTL() {
email := "ttl@example.com" email := "ttl@example.com"
emailTTL := 2 * time.Minute emailTTL := 2 * time.Minute
data := &ports.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()} data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode") require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
@@ -56,7 +56,7 @@ func (s *EmailCacheSuite) TestVerificationCode_TTL() {
func (s *EmailCacheSuite) TestDeleteVerificationCode() { func (s *EmailCacheSuite) TestDeleteVerificationCode() {
email := "delete@example.com" email := "delete@example.com"
data := &ports.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()} data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode") require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")

View File

@@ -0,0 +1,40 @@
package repository
import (
"errors"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm"
)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil {
return nil
}
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return notFound.WithCause(err)
}
if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err)
}
return err
}
func isUniqueConstraintViolation(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrDuplicatedKey) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") ||
strings.Contains(msg, "duplicate entry")
}

View File

@@ -4,8 +4,7 @@ import (
"context" "context"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -15,7 +14,7 @@ type gatewayCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache { func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb} return &gatewayCache{rdb: rdb}
} }

View File

@@ -7,7 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -15,7 +15,7 @@ import (
type GatewayCacheSuite struct { type GatewayCacheSuite struct {
IntegrationRedisSuite IntegrationRedisSuite
cache ports.GatewayCache cache service.GatewayCache
} }
func (s *GatewayCacheSuite) SetupTest() { func (s *GatewayCacheSuite) SetupTest() {

View File

@@ -2,47 +2,52 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type GroupRepository struct { type groupRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewGroupRepository(db *gorm.DB) *GroupRepository { func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &GroupRepository{db: db} return &groupRepository{db: db}
} }
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error { func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Create(group).Error err := r.db.WithContext(ctx).Create(group).Error
return translatePersistenceError(err, nil, service.ErrGroupExists)
} }
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error err := r.db.WithContext(ctx).First(&group, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
return &group, nil return &group, nil
} }
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error { func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error return r.db.WithContext(ctx).Save(group).Error
} }
func (r *GroupRepository) Delete(ctx context.Context, id int64) error { func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
} }
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", nil)
} }
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group var groups []model.Group
var total int64 var total int64
@@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination
}, nil }, nil
} }
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) { func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
@@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error)
return groups, nil return groups, nil
} }
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
@@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str
return groups, nil return groups, nil
} }
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系 // DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// DB 返回底层数据库连接,用于事务处理 func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
func (r *GroupRepository) DB() *gorm.DB { group, err := r.GetByID(ctx, id)
return r.db if err != nil {
return nil, err
}
var affectedUserIDs []int64
if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err != nil {
return nil, err
}
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return err
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err
}
// 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return affectedUserIDs, nil
} }

View File

@@ -16,13 +16,13 @@ type GroupRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *GroupRepository repo *groupRepository
} }
func (s *GroupRepoSuite) SetupTest() { func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewGroupRepository(s.db) s.repo = NewGroupRepository(s.db).(*groupRepository)
} }
func TestGroupRepoSuite(t *testing.T) { func TestGroupRepoSuite(t *testing.T) {
@@ -234,11 +234,3 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID) count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count) s.Require().Zero(count)
} }
// --- DB ---
func (s *GroupRepoSuite) TestDB() {
db := s.repo.DB()
s.Require().NotNil(db, "DB should return non-nil")
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
}

View File

@@ -6,7 +6,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
) )
// httpUpstreamService is a generic HTTP upstream service that can be used for // httpUpstreamService is a generic HTTP upstream service that can be used for
@@ -17,7 +17,7 @@ type httpUpstreamService struct {
} }
// NewHTTPUpstream creates a new generic HTTP upstream service // NewHTTPUpstream creates a new generic HTTP upstream service
func NewHTTPUpstream(cfg *config.Config) ports.HTTPUpstream { func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 { if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second responseHeaderTimeout = 300 * time.Second

View File

@@ -6,8 +6,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -20,24 +19,24 @@ type identityCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache { func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
return &identityCache{rdb: rdb} return &identityCache{rdb: rdb}
} }
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) { func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result() val, err := c.rdb.Get(ctx, key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var fp ports.Fingerprint var fp service.Fingerprint
if err := json.Unmarshal([]byte(val), &fp); err != nil { if err := json.Unmarshal([]byte(val), &fp); err != nil {
return nil, err return nil, err
} }
return &fp, nil return &fp, nil
} }
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error { func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
val, err := json.Marshal(fp) val, err := json.Marshal(fp)
if err != nil { if err != nil {

View File

@@ -8,7 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -30,7 +30,7 @@ func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
} }
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() { func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"} fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint") require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
gotFP, err := s.cache.GetFingerprint(s.ctx, 1) gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
require.NoError(s.T(), err, "GetFingerprint") require.NoError(s.T(), err, "GetFingerprint")
@@ -39,7 +39,7 @@ func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
} }
func (s *IdentityCacheSuite) TestFingerprint_TTL() { func (s *IdentityCacheSuite) TestFingerprint_TTL() {
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"} fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp)) require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2) fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)

View File

@@ -7,13 +7,12 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )
// NewOpenAIOAuthClient creates a new OpenAI OAuth client // NewOpenAIOAuthClient creates a new OpenAI OAuth client
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient { func NewOpenAIOAuthClient() service.OpenAIOAuthClient {
return &openaiOAuthService{tokenURL: openai.TokenURL} return &openaiOAuthService{tokenURL: openai.TokenURL}
} }

View File

@@ -2,47 +2,50 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ProxyRepository struct { type proxyRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewProxyRepository(db *gorm.DB) *ProxyRepository { func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &ProxyRepository{db: db} return &proxyRepository{db: db}
} }
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error return r.db.WithContext(ctx).Create(proxy).Error
} }
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
var proxy model.Proxy var proxy model.Proxy
err := r.db.WithContext(ctx).First(&proxy, id).Error err := r.db.WithContext(ctx).First(&proxy, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
} }
return &proxy, nil return &proxy, nil
} }
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error return r.db.WithContext(ctx).Save(proxy).Error
} }
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error { func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
} }
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query // ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy var proxies []model.Proxy
var total int64 var total int64
@@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination
}, nil }, nil
} }
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
var proxies []model.Proxy var proxies []model.Proxy
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
return proxies, err return proxies, err
} }
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists // ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}). err := r.db.WithContext(ctx).Model(&model.Proxy{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password). Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
@@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
} }
// CountAccountsByProxyID returns the number of accounts using a specific proxy // CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}). err := r.db.WithContext(ctx).Model(&model.Account{}).
Where("proxy_id = ?", proxyID). Where("proxy_id = ?", proxyID).
@@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
} }
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) { func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
type result struct { type result struct {
ProxyID int64 `gorm:"column:proxy_id"` ProxyID int64 `gorm:"column:proxy_id"`
Count int64 `gorm:"column:count"` Count int64 `gorm:"column:count"`
@@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending // ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
var proxies []model.Proxy var proxies []model.Proxy
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", model.StatusActive).

View File

@@ -17,13 +17,13 @@ type ProxyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *ProxyRepository repo *proxyRepository
} }
func (s *ProxyRepoSuite) SetupTest() { func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db) s.repo = NewProxyRepository(s.db).(*proxyRepository)
} }
func TestProxyRepoSuite(t *testing.T) { func TestProxyRepoSuite(t *testing.T) {

View File

@@ -5,8 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -20,7 +19,7 @@ type redeemCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache { func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
return &redeemCache{rdb: rdb} return &redeemCache{rdb: rdb}
} }

View File

@@ -2,57 +2,60 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
type RedeemCodeRepository struct { type redeemCodeRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository { func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &RedeemCodeRepository{db: db} return &redeemCodeRepository{db: db}
} }
func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error return r.db.WithContext(ctx).Create(code).Error
} }
func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error return r.db.WithContext(ctx).Create(&codes).Error
} }
func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
var code model.RedeemCode var code model.RedeemCode
err := r.db.WithContext(ctx).First(&code, id).Error err := r.db.WithContext(ctx).First(&code, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &code, nil return &code, nil
} }
func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
var redeemCode model.RedeemCode var redeemCode model.RedeemCode
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &redeemCode, nil return &redeemCode, nil
} }
func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error { func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
} }
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query // ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode var codes []model.RedeemCode
var total int64 var total int64
@@ -91,11 +94,11 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
}, nil }, nil
} }
func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error return r.db.WithContext(ctx).Save(code).Error
} }
func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error { func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now() now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused). Where("id = ? AND status = ?", id, model.StatusUnused).
@@ -108,13 +111,13 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 兑换码不存在或已被使用 return service.ErrRedeemCodeUsed.WithCause(gorm.ErrRecordNotFound)
} }
return nil return nil
} }
// ListByUser returns all redeem codes used by a specific user // ListByUser returns all redeem codes used by a specific user
func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) { func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode var codes []model.RedeemCode
if limit <= 0 { if limit <= 0 {
limit = 10 limit = 10

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -17,13 +18,13 @@ type RedeemCodeRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *RedeemCodeRepository repo *redeemCodeRepository
} }
func (s *RedeemCodeRepoSuite) SetupTest() { func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewRedeemCodeRepository(s.db) s.repo = NewRedeemCodeRepository(s.db).(*redeemCodeRepository)
} }
func TestRedeemCodeRepoSuite(t *testing.T) { func TestRedeemCodeRepoSuite(t *testing.T) {
@@ -195,7 +196,7 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
// Second use should fail // Second use should fail
err = s.repo.Use(s.ctx, code.ID, user.ID) err = s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "Use expected error on second call") s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
} }
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
@@ -204,7 +205,7 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code") s.Require().Error(err, "expected error for already used code")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
} }
// --- ListByUser --- // --- ListByUser ---
@@ -298,7 +299,7 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use") s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
err = s.repo.Use(s.ctx, codeB.ID, user.ID) err = s.repo.Use(s.ctx, codeB.ID, user.ID)
s.Require().Error(err, "Use expected error on second call") s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
codeA, err := s.repo.GetByCode(s.ctx, "CODEA") codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")

View File

@@ -1,14 +0,0 @@
package repository
// Repositories 所有仓库的集合
type Repositories struct {
User *UserRepository
ApiKey *ApiKeyRepository
Group *GroupRepository
Account *AccountRepository
Proxy *ProxyRepository
RedeemCode *RedeemCodeRepository
UsageLog *UsageLogRepository
Setting *SettingRepository
UserSubscription *UserSubscriptionRepository
}

View File

@@ -2,35 +2,38 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/model"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// SettingRepository 系统设置数据访问层 // SettingRepository 系统设置数据访问层
type SettingRepository struct { type settingRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewSettingRepository 创建系统设置仓库实例 // NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) *SettingRepository { func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &SettingRepository{db: db} return &settingRepository{db: db}
} }
// Get 根据Key获取设置值 // Get 根据Key获取设置值
func (r *SettingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
var setting model.Setting var setting model.Setting
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
} }
return &setting, nil return &setting, nil
} }
// GetValue 获取设置值字符串 // GetValue 获取设置值字符串
func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, error) { func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key) setting, err := r.Get(ctx, key)
if err != nil { if err != nil {
return "", err return "", err
@@ -39,7 +42,7 @@ func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, e
} }
// Set 设置值(存在则更新,不存在则创建) // Set 设置值(存在则更新,不存在则创建)
func (r *SettingRepository) Set(ctx context.Context, key, value string) error { func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{ setting := &model.Setting{
Key: key, Key: key,
Value: value, Value: value,
@@ -53,7 +56,7 @@ func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
} }
// GetMultiple 批量获取设置 // GetMultiple 批量获取设置
func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting var settings []model.Setting
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil { if err != nil {
@@ -68,7 +71,7 @@ func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map
} }
// SetMultiple 批量设置值 // SetMultiple 批量设置值
func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings { for key, value := range settings {
setting := &model.Setting{ setting := &model.Setting{
@@ -88,7 +91,7 @@ func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string
} }
// GetAll 获取所有设置 // GetAll 获取所有设置
func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, error) { func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting var settings []model.Setting
err := r.db.WithContext(ctx).Find(&settings).Error err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil { if err != nil {
@@ -103,6 +106,6 @@ func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, erro
} }
// Delete 删除设置 // Delete 删除设置
func (r *SettingRepository) Delete(ctx context.Context, key string) error { func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
} }

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -14,13 +15,13 @@ type SettingRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *SettingRepository repo *settingRepository
} }
func (s *SettingRepoSuite) SetupTest() { func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewSettingRepository(s.db) s.repo = NewSettingRepository(s.db).(*settingRepository)
} }
func TestSettingRepoSuite(t *testing.T) { func TestSettingRepoSuite(t *testing.T) {
@@ -45,7 +46,7 @@ func (s *SettingRepoSuite) TestSet_Upsert() {
func (s *SettingRepoSuite) TestGetValue_Missing() { func (s *SettingRepoSuite) TestGetValue_Missing() {
_, err := s.repo.GetValue(s.ctx, "nonexistent") _, err := s.repo.GetValue(s.ctx, "nonexistent")
s.Require().Error(err, "expected error for missing key") s.Require().Error(err, "expected error for missing key")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrSettingNotFound)
} }
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() { func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
@@ -86,7 +87,7 @@ func (s *SettingRepoSuite) TestDelete() {
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete") s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
_, err := s.repo.GetValue(s.ctx, "todelete") _, err := s.repo.GetValue(s.ctx, "todelete")
s.Require().Error(err, "expected missing key error after Delete") s.Require().Error(err, "expected missing key error after Delete")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrSettingNotFound)
} }
func (s *SettingRepoSuite) TestDelete_Idempotent() { func (s *SettingRepoSuite) TestDelete_Idempotent() {

View File

@@ -4,8 +4,7 @@ import (
"context" "context"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -15,7 +14,7 @@ type updateCache struct {
rdb *redis.Client rdb *redis.Client
} }
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache { func NewUpdateCache(rdb *redis.Client) service.UpdateCache {
return &updateCache{rdb: rdb} return &updateCache{rdb: rdb}
} }

View File

@@ -2,25 +2,28 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
type UsageLogRepository struct { type usageLogRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository { func NewUsageLogRepository(db *gorm.DB) service.UsageLogRepository {
return &UsageLogRepository{db: db} return &usageLogRepository{db: db}
} }
// getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤 // getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) { func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute) fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
var perfStats struct { var perfStats struct {
RequestCount int64 `gorm:"column:request_count"` RequestCount int64 `gorm:"column:request_count"`
@@ -43,20 +46,20 @@ func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5 return perfStats.RequestCount / 5, perfStats.TokenCount / 5
} }
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error return r.db.WithContext(ctx).Create(log).Error
} }
func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
var log model.UsageLog var log model.UsageLog
err := r.db.WithContext(ctx).First(&log, id).Error err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
} }
return &log, nil return &log, nil
} }
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -83,7 +86,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
}, nil }, nil
} }
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -120,7 +123,7 @@ type UserStats struct {
CacheReadTokens int64 `json:"cache_read_tokens"` CacheReadTokens int64 `json:"cache_read_tokens"`
} }
func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(` Select(`
@@ -139,7 +142,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
// DashboardStats 仪表盘统计 // DashboardStats 仪表盘统计
type DashboardStats = usagestats.DashboardStats type DashboardStats = usagestats.DashboardStats
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats var stats DashboardStats
today := timezone.Today() today := timezone.Today()
@@ -260,7 +263,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil return &stats, nil
} }
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -287,7 +290,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}, nil }, nil
} }
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
@@ -296,7 +299,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
@@ -305,7 +308,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
@@ -314,7 +317,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
@@ -323,12 +326,12 @@ func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelN
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error { func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
} }
// GetAccountTodayStats 获取账号今日统计 // GetAccountTodayStats 获取账号今日统计
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today() today := timezone.Today()
var stats struct { var stats struct {
@@ -358,7 +361,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
} }
// GetAccountWindowStats 获取账号时间窗口内的统计 // GetAccountWindowStats 获取账号时间窗口内的统计
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
var stats struct { var stats struct {
Requests int64 `gorm:"column:requests"` Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"` Tokens int64 `gorm:"column:tokens"`
@@ -398,7 +401,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date // GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
var results []ApiKeyUsageTrendPoint var results []ApiKeyUsageTrendPoint
// Choose date format based on granularity // Choose date format based on granularity
@@ -442,7 +445,7 @@ func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
} }
// GetUserUsageTrend returns usage trend data grouped by user and date // GetUserUsageTrend returns usage trend data grouped by user and date
func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) { func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
var results []UserUsageTrendPoint var results []UserUsageTrendPoint
// Choose date format based on granularity // Choose date format based on granularity
@@ -491,7 +494,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
type UserDashboardStats = usagestats.UserDashboardStats type UserDashboardStats = usagestats.UserDashboardStats
// GetUserDashboardStats 获取用户专属的仪表盘统计 // GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
var stats UserDashboardStats var stats UserDashboardStats
today := timezone.Today() today := timezone.Today()
@@ -578,7 +581,7 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
} }
// GetUserUsageTrendByUserID 获取指定用户的使用趋势 // GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) { func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
var results []TrendDataPoint var results []TrendDataPoint
var dateFormat string var dateFormat string
@@ -612,7 +615,7 @@ func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
} }
// GetUserModelStats 获取指定用户的模型统计 // GetUserModelStats 获取指定用户的模型统计
func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
@@ -641,7 +644,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -692,7 +695,7 @@ type UsageStats = usagestats.UsageStats
type BatchUserUsageStats = usagestats.BatchUserUsageStats type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users // GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
if len(userIDs) == 0 { if len(userIDs) == 0 {
return make(map[int64]*BatchUserUsageStats), nil return make(map[int64]*BatchUserUsageStats), nil
} }
@@ -752,7 +755,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys // GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return make(map[int64]*BatchApiKeyUsageStats), nil return make(map[int64]*BatchApiKeyUsageStats), nil
} }
@@ -809,7 +812,7 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
} }
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters // GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
var results []TrendDataPoint var results []TrendDataPoint
var dateFormat string var dateFormat string
@@ -848,7 +851,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
} }
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters // GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
@@ -882,7 +885,7 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
} }
// GetGlobalStats gets usage statistics for all users within a time range // GetGlobalStats gets usage statistics for all users within a time range
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
var stats struct { var stats struct {
TotalRequests int64 `gorm:"column:total_requests"` TotalRequests int64 `gorm:"column:total_requests"`
TotalInputTokens int64 `gorm:"column:total_input_tokens"` TotalInputTokens int64 `gorm:"column:total_input_tokens"`
@@ -932,7 +935,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) { func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1 daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 { if daysCount <= 0 {
daysCount = 30 daysCount = 30

View File

@@ -19,13 +19,13 @@ type UsageLogRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UsageLogRepository repo *usageLogRepository
} }
func (s *UsageLogRepoSuite) SetupTest() { func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db) s.repo = NewUsageLogRepository(s.db).(*usageLogRepository)
} }
func TestUsageLogRepoSuite(t *testing.T) { func TestUsageLogRepoSuite(t *testing.T) {

View File

@@ -2,56 +2,61 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type UserRepository struct { type userRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewUserRepository(db *gorm.DB) *UserRepository { func NewUserRepository(db *gorm.DB) service.UserRepository {
return &UserRepository{db: db} return &userRepository{db: db}
} }
func (r *UserRepository) Create(ctx context.Context, user *model.User) error { func (r *userRepository) Create(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Create(user).Error err := r.db.WithContext(ctx).Create(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }
func (r *UserRepository) Update(ctx context.Context, user *model.User) error { func (r *userRepository) Update(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Save(user).Error err := r.db.WithContext(ctx).Save(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *UserRepository) Delete(ctx context.Context, id int64) error { func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
} }
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists users with optional filtering by status, role, and search query // ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User var users []model.User
var total int64 var total int64
@@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
}, nil }, nil
} }
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error Update("balance", gorm.Expr("balance + ?", amount)).Error
} }
// DeductBalance 扣减用户余额,仅当余额充足时执行 // DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&model.User{}).
Where("id = ? AND balance >= ?", id, amount). Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount)) Update("balance", gorm.Expr("balance - ?", amount))
@@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 余额不足或用户不存在 return service.ErrInsufficientBalance
} }
return nil return nil
} }
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
} }
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
return count > 0, err return count > 0, err
@@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool,
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID // RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数 // 使用 PostgreSQL 的 array_remove 函数
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&model.User{}).
Where("? = ANY(allowed_groups)", groupID). Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
@@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
} }
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证) // GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive). Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Order("id ASC"). Order("id ASC").
First(&user).Error First(&user).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
@@ -18,13 +19,13 @@ type UserRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UserRepository repo *userRepository
} }
func (s *UserRepoSuite) SetupTest() { func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUserRepository(s.db) s.repo = NewUserRepository(s.db).(*userRepository)
} }
func TestUserRepoSuite(t *testing.T) { func TestUserRepoSuite(t *testing.T) {
@@ -247,7 +248,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
err := s.repo.DeductBalance(s.ctx, user.ID, 999) err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance") s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrInsufficientBalance)
} }
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
@@ -432,7 +433,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
err = s.repo.DeductBalance(s.ctx, user1.ID, 999) err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance") s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error") s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency") s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID) got5, err := s.repo.GetByID(s.ctx, user1.ID)

View File

@@ -6,27 +6,29 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
// UserSubscriptionRepository 用户订阅仓库 // UserSubscriptionRepository 用户订阅仓库
type UserSubscriptionRepository struct { type userSubscriptionRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewUserSubscriptionRepository 创建用户订阅仓库 // NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository { func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &UserSubscriptionRepository{db: db} return &userSubscriptionRepository{db: db}
} }
// Create 创建订阅 // Create 创建订阅
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
return r.db.WithContext(ctx).Create(sub).Error err := r.db.WithContext(ctx).Create(sub).Error
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
} }
// GetByID 根据ID获取订阅 // GetByID 根据ID获取订阅
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("User"). Preload("User").
@@ -34,26 +36,26 @@ func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*mo
Preload("AssignedByUser"). Preload("AssignedByUser").
First(&sub, id).Error First(&sub, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅 // GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error First(&sub).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅 // GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
@@ -61,24 +63,24 @@ func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
userID, groupID, model.SubscriptionStatusActive, time.Now()). userID, groupID, model.SubscriptionStatusActive, time.Now()).
First(&sub).Error First(&sub).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// Update 更新订阅 // Update 更新订阅
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error { func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now() sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error return r.db.WithContext(ctx).Save(sub).Error
} }
// Delete 删除订阅 // Delete 删除订阅
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error { func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
} }
// ListByUserID 获取用户的所有订阅 // ListByUserID 获取用户的所有订阅
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
@@ -89,7 +91,7 @@ func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID in
} }
// ListActiveByUserID 获取用户的所有有效订阅 // ListActiveByUserID 获取用户的所有有效订阅
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
@@ -101,7 +103,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
} }
// ListByGroupID 获取分组的所有订阅(分页) // ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
@@ -136,7 +138,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
} }
// List 获取所有订阅(分页,支持筛选) // List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
@@ -182,7 +184,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
} }
// IncrementUsage 增加使用量 // IncrementUsage 增加使用量
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -194,7 +196,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
} }
// ResetDailyUsage 重置日使用量 // ResetDailyUsage 重置日使用量
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -205,7 +207,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
} }
// ResetWeeklyUsage 重置周使用量 // ResetWeeklyUsage 重置周使用量
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -216,7 +218,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
} }
// ResetMonthlyUsage 重置月使用量 // ResetMonthlyUsage 重置月使用量
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -227,7 +229,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
} }
// ActivateWindows 激活所有窗口(首次使用时) // ActivateWindows 激活所有窗口(首次使用时)
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -239,7 +241,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
} }
// UpdateStatus 更新订阅状态 // UpdateStatus 更新订阅状态
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -249,7 +251,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
} }
// ExtendExpiry 延长订阅过期时间 // ExtendExpiry 延长订阅过期时间
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -259,7 +261,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
} }
// UpdateNotes 更新订阅备注 // UpdateNotes 更新订阅备注
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -269,7 +271,7 @@ func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64,
} }
// ListExpired 获取所有已过期但状态仍为active的订阅 // ListExpired 获取所有已过期但状态仍为active的订阅
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
@@ -278,7 +280,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
} }
// BatchUpdateExpiredStatus 批量更新过期订阅状态 // BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{ Updates(map[string]any{
@@ -289,7 +291,7 @@ func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
} }
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅 // ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
@@ -298,7 +300,7 @@ func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Contex
} }
// CountByGroupID 获取分组的订阅数量 // CountByGroupID 获取分组的订阅数量
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
@@ -307,7 +309,7 @@ func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID
} }
// CountActiveByGroupID 获取分组的有效订阅数量 // CountActiveByGroupID 获取分组的有效订阅数量
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ? AND status = ? AND expires_at > ?", Where("group_id = ? AND status = ? AND expires_at > ?",
@@ -317,7 +319,7 @@ func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
} }
// DeleteByGroupID 删除分组相关的所有订阅记录 // DeleteByGroupID 删除分组相关的所有订阅记录
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }

View File

@@ -17,13 +17,13 @@ type UserSubscriptionRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UserSubscriptionRepository repo *userSubscriptionRepository
} }
func (s *UserSubscriptionRepoSuite) SetupTest() { func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db) s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository)
} }
func TestUserSubscriptionRepoSuite(t *testing.T) { func TestUserSubscriptionRepoSuite(t *testing.T) {

View File

@@ -1,8 +1,6 @@
package repository package repository
import ( import (
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/google/wire" "github.com/google/wire"
) )
@@ -17,7 +15,6 @@ var ProviderSet = wire.NewSet(
NewUsageLogRepository, NewUsageLogRepository,
NewSettingRepository, NewSettingRepository,
NewUserSubscriptionRepository, NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"),
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,
@@ -38,15 +35,4 @@ var ProviderSet = wire.NewSet(
NewClaudeOAuthClient, NewClaudeOAuthClient,
NewHTTPUpstream, NewHTTPUpstream,
NewOpenAIOAuthClient, NewOpenAIOAuthClient,
// Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
) )

View File

@@ -1,13 +1,13 @@
package server package server
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/wire" "github.com/google/wire"
) )
@@ -19,15 +19,21 @@ var ProviderSet = wire.NewSet(
) )
// ProvideRouter 提供路由器 // ProvideRouter 提供路由器
func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine { func ProvideRouter(
cfg *config.Config,
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
) *gin.Engine {
if cfg.Server.Mode == "release" { if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
r := gin.New() r := gin.New()
r.Use(gin.Recovery()) r.Use(middleware2.Recovery())
return SetupRouter(r, cfg, handlers, services, repos) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth)
} }
// ProvideHTTPServer 提供 HTTP 服务器 // ProvideHTTPServer 提供 HTTP 服务器

View File

@@ -1,32 +1,39 @@
package middleware package middleware
import ( import (
"context"
"crypto/subtle" "crypto/subtle"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// AdminAuth 管理员认证中间件 // NewAdminAuthMiddleware 创建管理员认证中间件
func NewAdminAuthMiddleware(
authService *service.AuthService,
userService *service.UserService,
settingService *service.SettingService,
) AdminAuthMiddleware {
return AdminAuthMiddleware(adminAuth(authService, userService, settingService))
}
// adminAuth 管理员认证中间件实现
// 支持两种认证方式(通过不同的 header 区分): // 支持两种认证方式(通过不同的 header 区分):
// 1. Admin API Key: x-api-key: <admin-api-key> // 1. Admin API Key: x-api-key: <admin-api-key>
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色) // 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
func AdminAuth( func adminAuth(
authService *service.AuthService, authService *service.AuthService,
userRepo interface { userService *service.UserService,
GetByID(ctx context.Context, id int64) (*model.User, error)
GetFirstAdmin(ctx context.Context) (*model.User, error)
},
settingService *service.SettingService, settingService *service.SettingService,
) gin.HandlerFunc { ) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 检查 x-api-key headerAdmin API Key 认证) // 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key") apiKey := c.GetHeader("x-api-key")
if apiKey != "" { if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userRepo) { if !validateAdminApiKey(c, apiKey, settingService, userService) {
return return
} }
c.Next() c.Next()
@@ -38,7 +45,7 @@ func AdminAuth(
if authHeader != "" { if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2) parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" { if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userRepo) { if !validateJWTForAdmin(c, parts[1], authService, userService) {
return return
} }
c.Next() c.Next()
@@ -56,9 +63,7 @@ func validateAdminApiKey(
c *gin.Context, c *gin.Context,
key string, key string,
settingService *service.SettingService, settingService *service.SettingService,
userRepo interface { userService *service.UserService,
GetFirstAdmin(ctx context.Context) (*model.User, error)
},
) bool { ) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context()) storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
if err != nil { if err != nil {
@@ -73,7 +78,7 @@ func validateAdminApiKey(
} }
// 获取真实的管理员用户 // 获取真实的管理员用户
admin, err := userRepo.GetFirstAdmin(c.Request.Context()) admin, err := userService.GetFirstAdmin(c.Request.Context())
if err != nil { if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found") AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
return false return false
@@ -89,14 +94,12 @@ func validateJWTForAdmin(
c *gin.Context, c *gin.Context,
token string, token string,
authService *service.AuthService, authService *service.AuthService,
userRepo interface { userService *service.UserService,
GetByID(ctx context.Context, id int64) (*model.User, error)
},
) bool { ) bool {
// 验证 JWT token // 验证 JWT token
claims, err := authService.ValidateToken(token) claims, err := authService.ValidateToken(token)
if err != nil { if err != nil {
if err == service.ErrTokenExpired { if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired") AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false return false
} }
@@ -105,7 +108,7 @@ func validateJWTForAdmin(
} }
// 从数据库获取用户 // 从数据库获取用户
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID) user, err := userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil { if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found") AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return false return false

View File

@@ -1,37 +1,24 @@
package middleware package middleware
import ( import (
"context"
"errors" "errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"log" "log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
// ApiKeyAuthService 定义API Key认证服务需要的接口 // NewApiKeyAuthMiddleware 创建 API Key 认证中间件
type ApiKeyAuthService interface { func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) ApiKeyAuthMiddleware {
GetByKey(ctx context.Context, key string) (*model.ApiKey, error) return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService))
} }
// SubscriptionAuthService 定义订阅认证服务需要的接口 // apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
type SubscriptionAuthService interface { func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc {
GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error
CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error
CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error
CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error
}
// ApiKeyAuth API Key认证中间件
func ApiKeyAuth(apiKeyRepo ApiKeyAuthService) gin.HandlerFunc {
return ApiKeyAuthWithSubscription(apiKeyRepo, nil)
}
// ApiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func ApiKeyAuthWithSubscription(apiKeyRepo ApiKeyAuthService, subscriptionService SubscriptionAuthService) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme) // 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
@@ -57,7 +44,7 @@ func ApiKeyAuthWithSubscription(apiKeyRepo ApiKeyAuthService, subscriptionServic
} }
// 从数据库验证API key // 从数据库验证API key
apiKey, err := apiKeyRepo.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")

View File

@@ -1,18 +1,22 @@
package middleware package middleware
import ( import (
"context" "errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// JWTAuth JWT认证中间件 // NewJWTAuthMiddleware 创建 JWT 认证中间件
func JWTAuth(authService *service.AuthService, userRepo interface { func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
GetByID(ctx context.Context, id int64) (*model.User, error) return JWTAuthMiddleware(jwtAuth(authService, userService))
}) gin.HandlerFunc { }
// jwtAuth JWT认证中间件实现
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 从Authorization header中提取token // 从Authorization header中提取token
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
@@ -37,7 +41,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface {
// 验证token // 验证token
claims, err := authService.ValidateToken(tokenString) claims, err := authService.ValidateToken(tokenString)
if err != nil { if err != nil {
if err == service.ErrTokenExpired { if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired") AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return return
} }
@@ -46,7 +50,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface {
} }
// 从数据库获取最新的用户信息 // 从数据库获取最新的用户信息
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID) user, err := userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil { if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found") AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return return

View File

@@ -0,0 +1,64 @@
package middleware
import (
"errors"
"net"
"net/http"
"os"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
)
// Recovery converts panics into the project's standard JSON error envelope.
//
// It preserves Gin's broken-pipe handling by not attempting to write a response
// when the client connection is already gone.
func Recovery() gin.HandlerFunc {
return gin.CustomRecoveryWithWriter(gin.DefaultErrorWriter, func(c *gin.Context, recovered any) {
recoveredErr, _ := recovered.(error)
if isBrokenPipe(recoveredErr) {
if recoveredErr != nil {
_ = c.Error(recoveredErr)
}
c.Abort()
return
}
if c.Writer.Written() {
c.Abort()
return
}
response.ErrorWithDetails(
c,
http.StatusInternalServerError,
infraerrors.UnknownMessage,
infraerrors.UnknownReason,
nil,
)
c.Abort()
})
}
func isBrokenPipe(err error) bool {
if err == nil {
return false
}
var opErr *net.OpError
if !errors.As(err, &opErr) {
return false
}
var syscallErr *os.SyscallError
if !errors.As(opErr.Err, &syscallErr) {
return false
}
msg := strings.ToLower(syscallErr.Error())
return strings.Contains(msg, "broken pipe") || strings.Contains(msg, "connection reset by peer")
}

View File

@@ -0,0 +1,81 @@
//go:build unit
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestRecovery(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
handler gin.HandlerFunc
wantHTTPCode int
wantBody response.Response
}{
{
name: "panic_returns_standard_json_500",
handler: func(c *gin.Context) {
panic("boom")
},
wantHTTPCode: http.StatusInternalServerError,
wantBody: response.Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
},
},
{
name: "no_panic_passthrough",
handler: func(c *gin.Context) {
response.Success(c, gin.H{"ok": true})
},
wantHTTPCode: http.StatusOK,
wantBody: response.Response{
Code: 0,
Message: "success",
Data: map[string]any{"ok": true},
},
},
{
name: "panic_after_write_does_not_override_body",
handler: func(c *gin.Context) {
response.Success(c, gin.H{"ok": true})
panic("boom")
},
wantHTTPCode: http.StatusOK,
wantBody: response.Response{
Code: 0,
Message: "success",
Data: map[string]any{"ok": true},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := gin.New()
r.Use(Recovery())
r.GET("/t", tt.handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
require.Equal(t, tt.wantHTTPCode, w.Code)
var got response.Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -0,0 +1,22 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// JWTAuthMiddleware JWT 认证中间件类型
type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware,
NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware,
)

View File

@@ -1,312 +1,54 @@
package server package server
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/server/routes"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/web" "github.com/Wei-Shaw/sub2api/internal/web"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// SetupRouter 配置路由器中间件和路由 // SetupRouter 配置路由器中间件和路由
func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine { func SetupRouter(
r *gin.Engine,
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
) *gin.Engine {
// 应用中间件 // 应用中间件
r.Use(middleware.Logger()) r.Use(middleware2.Logger())
r.Use(middleware.CORS()) r.Use(middleware2.CORS())
// 注册路由
registerRoutes(r, handlers, services, repos)
// Serve embedded frontend if available // Serve embedded frontend if available
if web.HasEmbeddedFrontend() { if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend()) r.Use(web.ServeEmbeddedFrontend())
} }
// 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth)
return r return r
} }
// registerRoutes 注册所有 HTTP 路由 // registerRoutes 注册所有 HTTP 路由
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) { func registerRoutes(
// 健康检查 r *gin.Engine,
r.GET("/health", func(c *gin.Context) { h *handler.Handlers,
c.JSON(http.StatusOK, gin.H{"status": "ok"}) jwtAuth middleware2.JWTAuthMiddleware,
}) adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
// Claude Code 遥测日志忽略直接返回200 ) {
r.POST("/api/event_logging/batch", func(c *gin.Context) { // 通用路由(健康检查、状态等)
c.Status(http.StatusOK) routes.RegisterCommonRoutes(r)
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
// API v1 // API v1
v1 := r.Group("/api/v1") v1 := r.Group("/api/v1")
{
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证) // 注册各模块路由
settings := v1.Group("/settings") routes.RegisterAuthRoutes(v1, h, jwtAuth)
{ routes.RegisterUserRoutes(v1, h, jwtAuth)
settings.GET("/public", h.Setting.GetPublicSettings) routes.RegisterAdminRoutes(v1, h, adminAuth)
} routes.RegisterGatewayRoutes(r, h, apiKeyAuth)
// 需要认证的接口
authenticated := v1.Group("")
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
{
// 当前用户信息
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
// 管理员接口
admin := v1.Group("/admin")
admin.Use(middleware.AdminAuth(s.Auth, repos.User, s.Setting))
{
// 仪表盘
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
// 用户管理
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
// 分组管理
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
// 账号管理
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
// OpenAI OAuth routes
openai := admin.Group("/openai")
{
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
// 代理管理
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
// 卡密管理
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
// 系统设置
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
}
// 系统管理
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
// 订阅管理
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
// 使用记录管理
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
}
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses)
} }

View File

@@ -0,0 +1,221 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterAdminRoutes 注册管理员路由
func RegisterAdminRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
adminAuth middleware.AdminAuthMiddleware,
) {
admin := v1.Group("/admin")
admin.Use(gin.HandlerFunc(adminAuth))
{
// 仪表盘
registerDashboardRoutes(admin, h)
// 用户管理
registerUserManagementRoutes(admin, h)
// 分组管理
registerGroupRoutes(admin, h)
// 账号管理
registerAccountRoutes(admin, h)
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
// 代理管理
registerProxyRoutes(admin, h)
// 卡密管理
registerRedeemCodeRoutes(admin, h)
// 系统设置
registerSettingsRoutes(admin, h)
// 系统管理
registerSystemRoutes(admin, h)
// 订阅管理
registerSubscriptionRoutes(admin, h)
// 使用记录管理
registerUsageRoutes(admin, h)
}
}
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
}
func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
}
}
func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}
func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
}
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
openai := admin.Group("/openai")
{
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
}
func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
}
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
}
}
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
}
func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
}
func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}

View File

@@ -0,0 +1,36 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterAuthRoutes 注册认证相关路由
func RegisterAuthRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的当前用户信息
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
}
}

View File

@@ -0,0 +1,32 @@
package routes
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
func RegisterCommonRoutes(r *gin.Engine) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Claude Code 遥测日志忽略直接返回200
r.POST("/api/event_logging/batch", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
}

View File

@@ -0,0 +1,30 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterGatewayRoutes 注册 API 网关路由Claude/OpenAI 兼容)
func RegisterGatewayRoutes(
r *gin.Engine,
h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware,
) {
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(gin.HandlerFunc(apiKeyAuth))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
}

View File

@@ -0,0 +1,72 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterUserRoutes 注册用户相关路由(需要认证)
func RegisterUserRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
}

View File

@@ -2,19 +2,63 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm"
) )
var ( var (
ErrAccountNotFound = errors.New("account not found") ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
) )
type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error)
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
Update(ctx context.Context, account *model.Account) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
ListActive(ctx context.Context) ([]model.Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
}
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
Name *string
ProxyID *int64
Concurrency *int
Priority *int
Status *string
Credentials map[string]any
Extra map[string]any
}
// CreateAccountRequest 创建账号请求 // CreateAccountRequest 创建账号请求
type CreateAccountRequest struct { type CreateAccountRequest struct {
Name string `json:"name"` Name string `json:"name"`
@@ -42,12 +86,12 @@ type UpdateAccountRequest struct {
// AccountService 账号管理服务 // AccountService 账号管理服务
type AccountService struct { type AccountService struct {
accountRepo ports.AccountRepository accountRepo AccountRepository
groupRepo ports.GroupRepository groupRepo GroupRepository
} }
// NewAccountService 创建账号服务实例 // NewAccountService 创建账号服务实例
func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService { func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
return &AccountService{ return &AccountService{
accountRepo: accountRepo, accountRepo: accountRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
@@ -61,9 +105,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
for _, groupID := range req.GroupIDs { for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID) _, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
} }
@@ -100,9 +141,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
} }
return account, nil return account, nil
@@ -139,9 +177,6 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) { func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
} }
@@ -184,9 +219,6 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
for _, groupID := range *req.GroupIDs { for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID) _, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
} }
@@ -204,9 +236,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 检查账号是否存在 // 检查账号是否存在
_, err := s.accountRepo.GetByID(ctx, id) _, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
@@ -221,9 +250,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
@@ -249,9 +275,6 @@ func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) { func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrAccountNotFound
}
return "", fmt.Errorf("get account: %w", err) return "", fmt.Errorf("get account: %w", err)
} }
@@ -262,9 +285,6 @@ func (s *AccountService) GetCredential(ctx context.Context, id int64, key string
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }

View File

@@ -17,8 +17,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -40,14 +38,14 @@ type TestEvent struct {
// AccountTestService handles account testing operations // AccountTestService handles account testing operations
type AccountTestService struct { type AccountTestService struct {
accountRepo ports.AccountRepository accountRepo AccountRepository
oauthService *OAuthService oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
httpUpstream ports.HTTPUpstream httpUpstream HTTPUpstream
} }
// NewAccountTestService creates a new AccountTestService // NewAccountTestService creates a new AccountTestService
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream ports.HTTPUpstream) *AccountTestService { func NewAccountTestService(accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream HTTPUpstream) *AccountTestService {
return &AccountTestService{ return &AccountTestService{
accountRepo: accountRepo, accountRepo: accountRepo,
oauthService: oauthService, oauthService: oauthService,

View File

@@ -8,10 +8,49 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
type UsageLogRepository interface {
Create(ctx context.Context, log *model.UsageLog) error
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
// Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
// Account stats
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
}
// usageCache 用于缓存usage数据 // usageCache 用于缓存usage数据
type usageCache struct { type usageCache struct {
data *UsageInfo data *UsageInfo
@@ -69,13 +108,13 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
type AccountUsageService struct { type AccountUsageService struct {
accountRepo ports.AccountRepository accountRepo AccountRepository
usageLogRepo ports.UsageLogRepository usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher usageFetcher ClaudeUsageFetcher
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService { func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,

View File

@@ -9,9 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm"
) )
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
@@ -221,24 +218,24 @@ type ProxyExitInfoProber interface {
// adminServiceImpl implements AdminService // adminServiceImpl implements AdminService
type adminServiceImpl struct { type adminServiceImpl struct {
userRepo ports.UserRepository userRepo UserRepository
groupRepo ports.GroupRepository groupRepo GroupRepository
accountRepo ports.AccountRepository accountRepo AccountRepository
proxyRepo ports.ProxyRepository proxyRepo ProxyRepository
apiKeyRepo ports.ApiKeyRepository apiKeyRepo ApiKeyRepository
redeemCodeRepo ports.RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
} }
// NewAdminService creates a new AdminService // NewAdminService creates a new AdminService
func NewAdminService( func NewAdminService(
userRepo ports.UserRepository, userRepo UserRepository,
groupRepo ports.GroupRepository, groupRepo GroupRepository,
accountRepo ports.AccountRepository, accountRepo AccountRepository,
proxyRepo ports.ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo ports.ApiKeyRepository, apiKeyRepo ApiKeyRepository,
redeemCodeRepo ports.RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
) AdminService { ) AdminService {
@@ -552,61 +549,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
} }
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
// 先获取分组信息,检查是否存在 affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("group not found: %w", err)
}
// 订阅类型分组先获取受影响的用户ID列表用于事务后失效缓存
var affectedUserIDs []int64
if group.IsSubscriptionType() && s.billingCacheService != nil {
var subscriptions []model.UserSubscription
if err := s.groupRepo.DB().WithContext(ctx).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err == nil {
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
}
// 使用事务处理所有级联删除
db := s.groupRepo.DB()
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return fmt.Errorf("delete user subscriptions: %w", err)
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil任何类型的分组都需要
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return fmt.Errorf("clear api key group_id: %w", err)
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return fmt.Errorf("remove from allowed_groups: %w", err)
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return fmt.Errorf("delete account groups: %w", err)
}
// 5. 删除分组本身
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
})
if err != nil { if err != nil {
return err return err
} }
@@ -695,6 +638,7 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
} }
if input.ProxyID != nil { if input.ProxyID != nil {
account.ProxyID = input.ProxyID account.ProxyID = input.ProxyID
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
} }
// 只在指针非 nil 时更新 Concurrency支持设置为 0 // 只在指针非 nil 时更新 Concurrency支持设置为 0
if input.Concurrency != nil { if input.Concurrency != nil {
@@ -719,7 +663,8 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
} }
} }
return account, nil // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象)
return s.accountRepo.GetByID(ctx, id)
} }
// BulkUpdateAccounts updates multiple accounts in one request. // BulkUpdateAccounts updates multiple accounts in one request.
@@ -734,7 +679,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
} }
// Prepare bulk updates for columns and JSONB fields. // Prepare bulk updates for columns and JSONB fields.
repoUpdates := ports.AccountBulkUpdate{ repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials, Credentials: input.Credentials,
Extra: input.Extra, Extra: input.Extra,
} }

View File

@@ -6,30 +6,55 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm"
) )
var ( var (
ErrApiKeyNotFound = errors.New("api key not found") ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group") ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists") ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters") ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens") ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later") ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
) )
const ( const (
apiKeyMaxErrorsPerHour = 20 apiKeyMaxErrorsPerHour = 20
) )
type ApiKeyRepository interface {
Create(ctx context.Context, key *model.ApiKey) error
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
Update(ctx context.Context, key *model.ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateApiKeyRequest 创建API Key请求 // CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct { type CreateApiKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
@@ -46,21 +71,21 @@ type UpdateApiKeyRequest struct {
// ApiKeyService API Key服务 // ApiKeyService API Key服务
type ApiKeyService struct { type ApiKeyService struct {
apiKeyRepo ports.ApiKeyRepository apiKeyRepo ApiKeyRepository
userRepo ports.UserRepository userRepo UserRepository
groupRepo ports.GroupRepository groupRepo GroupRepository
userSubRepo ports.UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache ports.ApiKeyCache cache ApiKeyCache
cfg *config.Config cfg *config.Config
} }
// NewApiKeyService 创建API Key服务实例 // NewApiKeyService 创建API Key服务实例
func NewApiKeyService( func NewApiKeyService(
apiKeyRepo ports.ApiKeyRepository, apiKeyRepo ApiKeyRepository,
userRepo ports.UserRepository, userRepo UserRepository,
groupRepo ports.GroupRepository, groupRepo GroupRepository,
userSubRepo ports.UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache ports.ApiKeyCache, cache ApiKeyCache,
cfg *config.Config, cfg *config.Config,
) *ApiKeyService { ) *ApiKeyService {
return &ApiKeyService{ return &ApiKeyService{
@@ -158,9 +183,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
@@ -168,9 +190,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
if req.GroupID != nil { if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID) group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
@@ -244,9 +263,6 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
return apiKey, nil return apiKey, nil
@@ -260,9 +276,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
// 这里可以添加Redis缓存逻辑暂时直接查询数据库 // 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key) apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
@@ -279,9 +292,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) { func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
@@ -304,9 +314,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
group, err := s.groupRepo.GetByID(ctx, *req.GroupID) group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
@@ -336,9 +343,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error { func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrApiKeyNotFound
}
return fmt.Errorf("get api key: %w", err) return fmt.Errorf("get api key: %w", err)
} }
@@ -369,15 +373,12 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// 检查API Key状态 // 检查API Key状态
if !apiKey.IsActive() { if !apiKey.IsActive() {
return nil, nil, errors.New("api key is not active") return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
} }
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID) user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrUserNotFound
}
return nil, nil, fmt.Errorf("get user: %w", err) return nil, nil, fmt.Errorf("get user: %w", err)
} }
@@ -411,9 +412,6 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
@@ -425,7 +423,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户的所有有效订阅 // 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID) activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err) return nil, fmt.Errorf("list active subscriptions: %w", err)
} }

View File

@@ -4,26 +4,26 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"log" "log"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
var ( var (
ErrInvalidCredentials = errors.New("invalid email or password") ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = errors.New("user is not active") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = errors.New("email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrInvalidToken = errors.New("invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = errors.New("token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = errors.New("service temporarily unavailable") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
) )
// JWTClaims JWT载荷数据 // JWTClaims JWT载荷数据
@@ -36,7 +36,7 @@ type JWTClaims struct {
// AuthService 认证服务 // AuthService 认证服务
type AuthService struct { type AuthService struct {
userRepo ports.UserRepository userRepo UserRepository
cfg *config.Config cfg *config.Config
settingService *SettingService settingService *SettingService
emailService *EmailService emailService *EmailService
@@ -46,7 +46,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
func NewAuthService( func NewAuthService(
userRepo ports.UserRepository, userRepo UserRepository,
cfg *config.Config, cfg *config.Config,
settingService *SettingService, settingService *SettingService,
emailService *EmailService, emailService *EmailService,
@@ -255,7 +255,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// 查找用户 // 查找用户
user, err := s.userRepo.GetByEmail(ctx, email) user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrUserNotFound) {
return "", nil, ErrInvalidCredentials return "", nil, ErrInvalidCredentials
} }
// 记录数据库错误但不暴露给用户 // 记录数据库错误但不暴露给用户
@@ -357,7 +357,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 获取最新的用户信息 // 获取最新的用户信息
user, err := s.userRepo.GetByID(ctx, claims.UserID) user, err := s.userRepo.GetByID(ctx, claims.UserID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrUserNotFound) {
return "", ErrInvalidToken return "", ErrInvalidToken
} }
log.Printf("[Auth] Database error refreshing token: %v", err) log.Printf("[Auth] Database error refreshing token: %v", err)

View File

@@ -0,0 +1,15 @@
package service
import (
"time"
)
// SubscriptionCacheData represents cached subscription data
type SubscriptionCacheData struct {
Status string
ExpiresAt time.Time
DailyUsage float64
WeeklyUsage float64
MonthlyUsage float64
Version int64
}

View File

@@ -2,20 +2,19 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// 错误定义 // 错误定义
// 注ErrInsufficientBalance在redeem_service.go中定义 // 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 // 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var ( var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -31,13 +30,13 @@ type subscriptionCacheData struct {
// BillingCacheService 计费缓存服务 // BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct { type BillingCacheService struct {
cache ports.BillingCache cache BillingCache
userRepo ports.UserRepository userRepo UserRepository
subRepo ports.UserSubscriptionRepository subRepo UserSubscriptionRepository
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService { func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository) *BillingCacheService {
return &BillingCacheService{ return &BillingCacheService{
cache: cache, cache: cache,
userRepo: userRepo, userRepo: userRepo,
@@ -149,7 +148,7 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
return data, nil return data, nil
} }
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData { func (s *BillingCacheService) convertFromPortsData(data *SubscriptionCacheData) *subscriptionCacheData {
return &subscriptionCacheData{ return &subscriptionCacheData{
Status: data.Status, Status: data.Status,
ExpiresAt: data.ExpiresAt, ExpiresAt: data.ExpiresAt,
@@ -160,8 +159,8 @@ func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCache
} }
} }
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData { func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *SubscriptionCacheData {
return &ports.SubscriptionCacheData{ return &SubscriptionCacheData{
Status: data.Status, Status: data.Status,
ExpiresAt: data.ExpiresAt, ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage, DailyUsage: data.DailyUsage,

View File

@@ -1,12 +1,30 @@
package service package service
import ( import (
"context"
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/config"
"log" "log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/config"
) )
// BillingCache defines cache operations for billing service
type BillingCache interface {
// Balance operations
GetUserBalance(ctx context.Context, userID int64) (float64, error)
SetUserBalance(ctx context.Context, userID int64, balance float64) error
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
InvalidateUserBalance(ctx context.Context, userID int64) error
// Subscription operations
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
}
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致 // ModelPricing 模型价格配置per-token价格与LiteLLM格式一致
type ModelPricing struct { type ModelPricing struct {
InputPricePerToken float64 // 每token输入价格 (USD) InputPricePerToken float64 // 每token输入价格 (USD)

View File

@@ -7,10 +7,28 @@ import (
"fmt" "fmt"
"log" "log"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// ConcurrencyCache defines cache operations for concurrency service
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
type ConcurrencyCache interface {
// Account slot management - each slot is a separate key with independent TTL
// Key format: concurrency:account:{accountID}:{requestID}
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// User slot management - each slot is a separate key with independent TTL
// Key format: concurrency:user:{userID}:{requestID}
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue - uses counter with TTL set only on creation
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
}
// generateRequestID generates a unique request ID for concurrency slot tracking // generateRequestID generates a unique request ID for concurrency slot tracking
// Uses 8 random bytes (16 hex chars) for uniqueness // Uses 8 random bytes (16 hex chars) for uniqueness
func generateRequestID() string { func generateRequestID() string {
@@ -29,11 +47,11 @@ const (
// ConcurrencyService manages concurrent request limiting for accounts and users // ConcurrencyService manages concurrent request limiting for accounts and users
type ConcurrencyService struct { type ConcurrencyService struct {
cache ports.ConcurrencyCache cache ConcurrencyCache
} }
// NewConcurrencyService creates a new ConcurrencyService // NewConcurrencyService creates a new ConcurrencyService
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService { func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
return &ConcurrencyService{cache: cache} return &ConcurrencyService{cache: cache}
} }

View File

@@ -13,19 +13,18 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
type CRSSyncService struct { type CRSSyncService struct {
accountRepo ports.AccountRepository accountRepo AccountRepository
proxyRepo ports.ProxyRepository proxyRepo ProxyRepository
oauthService *OAuthService oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService openaiOAuthService *OpenAIOAuthService
} }
func NewCRSSyncService( func NewCRSSyncService(
accountRepo ports.AccountRepository, accountRepo AccountRepository,
proxyRepo ports.ProxyRepository, proxyRepo ProxyRepository,
oauthService *OAuthService, oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
) *CRSSyncService { ) *CRSSyncService {

View File

@@ -6,15 +6,14 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
) )
// DashboardService provides aggregated statistics for admin dashboard. // DashboardService provides aggregated statistics for admin dashboard.
type DashboardService struct { type DashboardService struct {
usageRepo ports.UsageLogRepository usageRepo UsageLogRepository
} }
func NewDashboardService(usageRepo ports.UsageLogRepository) *DashboardService { func NewDashboardService(usageRepo UsageLogRepository) *DashboardService {
return &DashboardService{ return &DashboardService{
usageRepo: usageRepo, usageRepo: usageRepo,
} }

View File

@@ -4,23 +4,37 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"math/big" "math/big"
"net/smtp" "net/smtp"
"strconv" "strconv"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
var ( var (
ErrEmailNotConfigured = errors.New("email service not configured") ErrEmailNotConfigured = infraerrors.ServiceUnavailable("EMAIL_NOT_CONFIGURED", "email service not configured")
ErrInvalidVerifyCode = errors.New("invalid or expired verification code") ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code") ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code") ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
) )
// EmailCache defines cache operations for email service
type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
}
// VerificationCodeData represents verification code data
type VerificationCodeData struct {
Code string
Attempts int
CreatedAt time.Time
}
const ( const (
verifyCodeTTL = 15 * time.Minute verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute verifyCodeCooldown = 1 * time.Minute
@@ -40,12 +54,12 @@ type SmtpConfig struct {
// EmailService 邮件服务 // EmailService 邮件服务
type EmailService struct { type EmailService struct {
settingRepo ports.SettingRepository settingRepo SettingRepository
cache ports.EmailCache cache EmailCache
} }
// NewEmailService 创建邮件服务实例 // NewEmailService 创建邮件服务实例
func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService { func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailService {
return &EmailService{ return &EmailService{
settingRepo: settingRepo, settingRepo: settingRepo,
cache: cache, cache: cache,
@@ -205,7 +219,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
} }
// 保存验证码到 Redis // 保存验证码到 Redis
data := &ports.VerificationCodeData{ data := &VerificationCodeData{
Code: code, Code: code,
Attempts: 0, Attempts: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),

View File

@@ -19,7 +19,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -54,6 +53,13 @@ var allowedHeaders = map[string]bool{
"content-type": true, "content-type": true,
} }
// GatewayCache defines cache operations for gateway service
type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}
// ClaudeUsage 表示Claude API返回的usage信息 // ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct { type ClaudeUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
@@ -74,32 +80,32 @@ type ForwardResult struct {
// GatewayService handles API gateway operations // GatewayService handles API gateway operations
type GatewayService struct { type GatewayService struct {
accountRepo ports.AccountRepository accountRepo AccountRepository
usageLogRepo ports.UsageLogRepository usageLogRepo UsageLogRepository
userRepo ports.UserRepository userRepo UserRepository
userSubRepo ports.UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache ports.GatewayCache cache GatewayCache
cfg *config.Config cfg *config.Config
billingService *BillingService billingService *BillingService
rateLimitService *RateLimitService rateLimitService *RateLimitService
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
identityService *IdentityService identityService *IdentityService
httpUpstream ports.HTTPUpstream httpUpstream HTTPUpstream
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
func NewGatewayService( func NewGatewayService(
accountRepo ports.AccountRepository, accountRepo AccountRepository,
usageLogRepo ports.UsageLogRepository, usageLogRepo UsageLogRepository,
userRepo ports.UserRepository, userRepo UserRepository,
userSubRepo ports.UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache ports.GatewayCache, cache GatewayCache,
cfg *config.Config, cfg *config.Config,
billingService *BillingService, billingService *BillingService,
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
identityService *IdentityService, identityService *IdentityService,
httpUpstream ports.HTTPUpstream, httpUpstream HTTPUpstream,
) *GatewayService { ) *GatewayService {
return &GatewayService{ return &GatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
@@ -362,8 +368,8 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
// 重试相关常量 // 重试相关常量
const ( const (
maxRetries = 5 // 最大重试次数 maxRetries = 10 // 最大重试次数
retryDelay = 6 * time.Second // 重试等待时间 retryDelay = 3 * time.Second // 重试等待时间
) )
// shouldRetryUpstreamError 判断是否应该重试上游错误 // shouldRetryUpstreamError 判断是否应该重试上游错误
@@ -507,7 +513,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
// OAuth账号应用统一指纹 // OAuth账号应用统一指纹
var fingerprint *ports.Fingerprint var fingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID // 1. 获取或创建指纹包含随机生成的ClientID
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)

View File

@@ -2,20 +2,35 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"gorm.io/gorm"
) )
var ( var (
ErrGroupNotFound = errors.New("group not found") ErrGroupNotFound = infraerrors.NotFound("GROUP_NOT_FOUND", "group not found")
ErrGroupExists = errors.New("group name already exists") ErrGroupExists = infraerrors.Conflict("GROUP_EXISTS", "group name already exists")
) )
type GroupRepository interface {
Create(ctx context.Context, group *model.Group) error
GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// CreateGroupRequest 创建分组请求 // CreateGroupRequest 创建分组请求
type CreateGroupRequest struct { type CreateGroupRequest struct {
Name string `json:"name"` Name string `json:"name"`
@@ -35,11 +50,11 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务 // GroupService 分组管理服务
type GroupService struct { type GroupService struct {
groupRepo ports.GroupRepository groupRepo GroupRepository
} }
// NewGroupService 创建分组服务实例 // NewGroupService 创建分组服务实例
func NewGroupService(groupRepo ports.GroupRepository) *GroupService { func NewGroupService(groupRepo GroupRepository) *GroupService {
return &GroupService{ return &GroupService{
groupRepo: groupRepo, groupRepo: groupRepo,
} }
@@ -76,9 +91,6 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
return group, nil return group, nil
@@ -106,9 +118,6 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) { func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
@@ -153,9 +162,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
// 检查分组是否存在 // 检查分组是否存在
_, err := s.groupRepo.GetByID(ctx, id) _, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrGroupNotFound
}
return fmt.Errorf("get group: %w", err) return fmt.Errorf("get group: %w", err)
} }
@@ -170,9 +176,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) { func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }

View File

@@ -1,4 +1,4 @@
package ports package service
import "net/http" import "net/http"

View File

@@ -7,7 +7,6 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"log" "log"
"net/http" "net/http"
"regexp" "regexp"
@@ -24,7 +23,7 @@ var (
) )
// 默认指纹值(当客户端未提供时使用) // 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = ports.Fingerprint{ var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)", UserAgent: "claude-cli/2.0.62 (external, cli)",
StainlessLang: "js", StainlessLang: "js",
StainlessPackageVersion: "0.52.0", StainlessPackageVersion: "0.52.0",
@@ -34,20 +33,38 @@ var defaultFingerprint = ports.Fingerprint{
StainlessRuntimeVersion: "v22.14.0", StainlessRuntimeVersion: "v22.14.0",
} }
// Fingerprint represents account fingerprint data
type Fingerprint struct {
ClientID string
UserAgent string
StainlessLang string
StainlessPackageVersion string
StainlessOS string
StainlessArch string
StainlessRuntime string
StainlessRuntimeVersion string
}
// IdentityCache defines cache operations for identity service
type IdentityCache interface {
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
}
// IdentityService 管理OAuth账号的请求身份指纹 // IdentityService 管理OAuth账号的请求身份指纹
type IdentityService struct { type IdentityService struct {
cache ports.IdentityCache cache IdentityCache
} }
// NewIdentityService 创建新的IdentityService // NewIdentityService 创建新的IdentityService
func NewIdentityService(cache ports.IdentityCache) *IdentityService { func NewIdentityService(cache IdentityCache) *IdentityService {
return &IdentityService{cache: cache} return &IdentityService{cache: cache}
} }
// GetOrCreateFingerprint 获取或创建账号的指纹 // GetOrCreateFingerprint 获取或创建账号的指纹
// 如果缓存存在检测user-agent版本新版本则更新 // 如果缓存存在检测user-agent版本新版本则更新
// 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存 // 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) { func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
// 尝试从缓存获取指纹 // 尝试从缓存获取指纹
cached, err := s.cache.GetFingerprint(ctx, accountID) cached, err := s.cache.GetFingerprint(ctx, accountID)
if err == nil && cached != nil { if err == nil && cached != nil {
@@ -79,8 +96,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
} }
// createFingerprintFromHeaders 从请求头创建指纹 // createFingerprintFromHeaders 从请求头创建指纹
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint { func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
fp := &ports.Fingerprint{} fp := &Fingerprint{}
// 获取User-Agent // 获取User-Agent
if ua := headers.Get("User-Agent"); ua != "" { if ua := headers.Get("User-Agent"); ua != "" {
@@ -109,7 +126,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
} }
// ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头) // ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头)
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) { func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
if fp == nil { if fp == nil {
return return
} }

View File

@@ -8,9 +8,15 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service/ports" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
// OpenAIOAuthClient interface for OpenAI OAuth operations
type OpenAIOAuthClient interface {
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
}
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
type ClaudeOAuthClient interface { type ClaudeOAuthClient interface {
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
@@ -22,12 +28,12 @@ type ClaudeOAuthClient interface {
// OAuthService handles OAuth authentication flows // OAuthService handles OAuth authentication flows
type OAuthService struct { type OAuthService struct {
sessionStore *oauth.SessionStore sessionStore *oauth.SessionStore
proxyRepo ports.ProxyRepository proxyRepo ProxyRepository
oauthClient ClaudeOAuthClient oauthClient ClaudeOAuthClient
} }
// NewOAuthService creates a new OAuth service // NewOAuthService creates a new OAuth service
func NewOAuthService(proxyRepo ports.ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService { func NewOAuthService(proxyRepo ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
return &OAuthService{ return &OAuthService{
sessionStore: oauth.NewSessionStore(), sessionStore: oauth.NewSessionStore(),
proxyRepo: proxyRepo, proxyRepo: proxyRepo,

Some files were not shown because too many files have changed in this diff Show More