mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-07 08:50:22 +08:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3463769dc | ||
|
|
d9e6cfc44d | ||
|
|
57fd172287 | ||
|
|
8d7a497553 | ||
|
|
b31698b9f2 | ||
|
|
eeaff85e47 | ||
|
|
f51ad2e126 |
@@ -19,14 +19,16 @@ linters:
|
||||
files:
|
||||
- "**/internal/service/**"
|
||||
deny:
|
||||
- pkg: sub2api/internal/repository
|
||||
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
|
||||
desc: "service must not import repository"
|
||||
- pkg: gorm.io/gorm
|
||||
desc: "service must not import gorm"
|
||||
handler-no-repository:
|
||||
list-mode: original
|
||||
files:
|
||||
- "**/internal/handler/**"
|
||||
deny:
|
||||
- pkg: sub2api/internal/repository
|
||||
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
|
||||
desc: "handler must not import repository"
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
|
||||
@@ -16,14 +16,14 @@ build-embed:
|
||||
@echo "构建完成: bin/server (with embedded frontend)"
|
||||
|
||||
test-unit:
|
||||
@go test ./... $(TEST_ARGS)
|
||||
@go test -tags unit ./... -count=1
|
||||
|
||||
test-integration:
|
||||
@go test -tags integration ./internal/repository -count=1 -race -parallel=8
|
||||
@go test -tags integration ./... -count=1 -race -parallel=8
|
||||
|
||||
test-cover-integration:
|
||||
@echo "运行集成测试并生成覆盖率报告..."
|
||||
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./internal/repository/...
|
||||
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./...
|
||||
@go tool cover -func=coverage.out | tail -1
|
||||
@go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "覆盖率报告已生成: coverage.html"
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/setup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
@@ -84,7 +84,7 @@ func main() {
|
||||
|
||||
func runSetupServer() {
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(middleware.Recovery())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// Register setup routes
|
||||
|
||||
@@ -4,18 +4,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
@@ -35,6 +36,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// 业务层 ProviderSets
|
||||
repository.ProviderSet,
|
||||
service.ProviderSet,
|
||||
middleware.ProviderSet,
|
||||
handler.ProviderSet,
|
||||
|
||||
// 服务器层 ProviderSet
|
||||
@@ -62,7 +64,11 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
func provideCleanup(
|
||||
db *gorm.DB,
|
||||
rdb *redis.Client,
|
||||
services *service.Services,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -74,23 +80,23 @@ func provideCleanup(
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
pricing.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"EmailQueueService", func() error {
|
||||
services.EmailQueue.Stop()
|
||||
emailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
services.OAuth.Stop()
|
||||
oauth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
services.OpenAIOAuth.Stop()
|
||||
openaiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
@@ -116,54 +117,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||
groupService := service.NewGroupService(groupRepository)
|
||||
accountService := service.NewAccountService(accountRepository, groupRepository)
|
||||
proxyService := service.NewProxyService(proxyRepository)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
|
||||
services := &service.Services{
|
||||
Auth: authService,
|
||||
User: userService,
|
||||
ApiKey: apiKeyService,
|
||||
Group: groupService,
|
||||
Account: accountService,
|
||||
Proxy: proxyService,
|
||||
Redeem: redeemService,
|
||||
Usage: usageService,
|
||||
Pricing: pricingService,
|
||||
Billing: billingService,
|
||||
BillingCache: billingCacheService,
|
||||
Admin: adminService,
|
||||
Gateway: gatewayService,
|
||||
OpenAIGateway: openAIGatewayService,
|
||||
OAuth: oAuthService,
|
||||
OpenAIOAuth: openAIOAuthService,
|
||||
RateLimit: rateLimitService,
|
||||
AccountUsage: accountUsageService,
|
||||
AccountTest: accountTestService,
|
||||
Setting: settingService,
|
||||
Email: emailService,
|
||||
EmailQueue: emailQueueService,
|
||||
Turnstile: turnstileService,
|
||||
Subscription: subscriptionService,
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
}
|
||||
repositories := &repository.Repositories{
|
||||
User: userRepository,
|
||||
ApiKey: apiKeyRepository,
|
||||
Group: groupRepository,
|
||||
Account: accountRepository,
|
||||
Proxy: proxyRepository,
|
||||
RedeemCode: redeemCodeRepository,
|
||||
UsageLog: usageLogRepository,
|
||||
Setting: settingRepository,
|
||||
UserSubscription: userSubscriptionRepository,
|
||||
}
|
||||
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
v := provideCleanup(db, client, services)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
|
||||
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -188,7 +148,11 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
func provideCleanup(
|
||||
db *gorm.DB,
|
||||
rdb *redis.Client,
|
||||
services *service.Services,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -199,23 +163,23 @@ func provideCleanup(
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
pricing.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"EmailQueueService", func() error {
|
||||
services.EmailQueue.Stop()
|
||||
emailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
services.OAuth.Stop()
|
||||
oauth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OpenAIOAuthService", func() error {
|
||||
services.OpenAIOAuth.Stop()
|
||||
openaiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
|
||||
@@ -117,7 +117,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list accounts: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -184,7 +184,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create account: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -218,7 +218,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete account: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -297,7 +297,7 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
||||
SyncProxies: syncProxies,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Sync failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -332,7 +332,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -349,7 +349,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -372,7 +372,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account credentials: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -403,7 +403,7 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
|
||||
|
||||
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get account stats: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -421,7 +421,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
|
||||
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to clear error: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -570,7 +570,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
Extra: req.Extra,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -595,7 +595,7 @@ func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
|
||||
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -613,7 +613,7 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
|
||||
|
||||
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate setup token URL: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -642,7 +642,7 @@ func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -664,7 +664,7 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -692,7 +692,7 @@ func (h *OAuthHandler) CookieAuth(c *gin.Context) {
|
||||
Scope: "full",
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Cookie auth failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -714,7 +714,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
||||
Scope: "inference",
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Cookie auth failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -732,7 +732,7 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
|
||||
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -750,7 +750,7 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
|
||||
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to clear rate limit: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -768,7 +768,7 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get today stats: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -797,7 +797,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
||||
|
||||
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update schedulable status: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
|
||||
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list groups: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get groups: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
|
||||
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Group not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create group: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -168,7 +168,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update group: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -186,7 +186,7 @@ func (h *GroupHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete group: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -225,7 +225,7 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
|
||||
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get group API keys: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to refresh token: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
// Use OpenAI OAuth service to refresh token
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update account credentials: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
GroupIDs: req.GroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create account: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
||||
|
||||
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list proxies: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
if withCount {
|
||||
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxies: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, proxies)
|
||||
@@ -78,7 +78,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
|
||||
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxies: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
|
||||
|
||||
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Proxy not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create proxy: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
|
||||
Status: strings.TrimSpace(req.Status),
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update proxy: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -171,7 +171,7 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete proxy: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -189,7 +189,7 @@ func (h *ProxyHandler) Test(c *gin.Context) {
|
||||
|
||||
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to test proxy: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
|
||||
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get proxy accounts: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -272,7 +272,7 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
|
||||
// Check for duplicates (same host, port, username, password)
|
||||
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
|
||||
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list redeem codes: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
|
||||
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Redeem code not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate redeem codes: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func (h *RedeemHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete redeem code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
|
||||
|
||||
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
|
||||
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to expire redeem code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -178,7 +178,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
// Get all codes without pagination (use large page size)
|
||||
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
|
||||
func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get settings: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -111,14 +111,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
response.InternalError(c, "Failed to update settings: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 重新获取设置返回
|
||||
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get updated settings: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
|
||||
err := h.emailService.TestSmtpConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -264,7 +264,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -292,7 +292,7 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
|
||||
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Subscription not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -168,7 +168,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
|
||||
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to extend subscription: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ func (h *SubscriptionHandler) Revoke(c *gin.Context) {
|
||||
|
||||
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -230,7 +230,7 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
@@ -168,7 +168,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
if userID > 0 {
|
||||
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, stats)
|
||||
@@ -178,7 +178,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
// Get global stats
|
||||
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -197,7 +197,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
// Limit to 30 results
|
||||
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to search users: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
|
||||
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to search API keys: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list users: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
|
||||
user, err := h.adminService.GetUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "User not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to create user: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update user: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func (h *UserHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.adminService.DeleteUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete user: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -186,7 +186,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update balance: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
|
||||
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user API keys: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
|
||||
|
||||
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user usage: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list API keys: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
|
||||
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -128,7 +128,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
|
||||
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to update API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -203,7 +203,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
|
||||
|
||||
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to delete API key: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,7 +227,7 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
|
||||
|
||||
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get available groups: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -66,14 +66,14 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||
if req.VerifyCode == "" {
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Registration failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -95,13 +95,13 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to send verification code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -122,13 +122,13 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
||||
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
|
||||
if err != nil {
|
||||
response.Unauthorized(c, "Login failed: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -41,13 +41,13 @@ func NewGatewayHandler(gatewayService *service.GatewayService, userService *serv
|
||||
// POST /v1/messages
|
||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -79,7 +79,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
streamStarted := false
|
||||
|
||||
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
@@ -171,7 +171,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// GET /v1/models
|
||||
// Returns different model lists based on the API key's group platform
|
||||
func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
apiKey, _ := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, _ := middleware2.GetApiKeyFromContext(c)
|
||||
|
||||
// Return OpenAI models for OpenAI platform groups
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform == "openai" {
|
||||
@@ -192,13 +192,13 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
// Usage handles getting account balance for CC Switch integration
|
||||
// GET /v1/usage
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -206,7 +206,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
|
||||
// 订阅模式:返回订阅限额信息
|
||||
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
||||
subscription, ok := middleware.GetSubscriptionFromContext(c)
|
||||
subscription, ok := middleware2.GetSubscriptionFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
|
||||
return
|
||||
@@ -328,13 +328,13 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
|
||||
// 特点:校验订阅/余额,但不计算并发、不记录使用量
|
||||
func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -362,7 +362,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取订阅信息(可能为nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -40,13 +40,13 @@ func NewOpenAIGatewayHandler(
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -91,7 +91,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
streamStarted := false
|
||||
|
||||
// Get subscription info (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. Check if wait queue is full
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
|
||||
@@ -57,7 +57,7 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
|
||||
|
||||
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to redeem code: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (h *RedeemHandler) GetHistory(c *gin.Context) {
|
||||
|
||||
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get history: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
|
||||
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get settings: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func (h *SubscriptionHandler) GetActive(c *gin.Context) {
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
// Get all active subscriptions with progress
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
|
||||
// Get all active subscriptions
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "API key not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != user.ID {
|
||||
@@ -77,7 +77,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
|
||||
}
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to list usage records: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
|
||||
|
||||
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Usage record not found")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -204,7 +204,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
|
||||
}
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -259,7 +259,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
|
||||
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get dashboard statistics")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
||||
|
||||
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get usage trend")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -317,7 +317,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
|
||||
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get model statistics")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -362,7 +362,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
// Verify ownership of all requested API keys
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to verify API key ownership")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -386,7 +386,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
|
||||
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get API key usage stats")
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
|
||||
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to get user profile: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
}
|
||||
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to change password: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
}
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to update profile: "+err.Error())
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
158
backend/internal/infrastructure/errors/errors.go
Normal file
158
backend/internal/infrastructure/errors/errors.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
UnknownCode = http.StatusInternalServerError
|
||||
UnknownReason = ""
|
||||
UnknownMessage = "internal error"
|
||||
)
|
||||
|
||||
type Status struct {
|
||||
Code int32 `json:"code"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ApplicationError is the standard error type used to control HTTP responses.
|
||||
//
|
||||
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
|
||||
type ApplicationError struct {
|
||||
Status
|
||||
cause error
|
||||
}
|
||||
|
||||
// Error is kept for backwards compatibility within this package.
|
||||
type Error = ApplicationError
|
||||
|
||||
func (e *ApplicationError) Error() string {
|
||||
if e == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
if e.cause == nil {
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
|
||||
}
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
|
||||
}
|
||||
|
||||
// Unwrap provides compatibility for Go 1.13 error chains.
|
||||
func (e *ApplicationError) Unwrap() error { return e.cause }
|
||||
|
||||
// Is matches each error in the chain with the target value.
|
||||
func (e *ApplicationError) Is(err error) bool {
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se.Code == e.Code && se.Reason == e.Reason
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WithCause attaches the underlying cause of the error.
|
||||
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
|
||||
err := Clone(e)
|
||||
err.cause = cause
|
||||
return err
|
||||
}
|
||||
|
||||
// WithMetadata deep-copies the given metadata map.
|
||||
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
|
||||
err := Clone(e)
|
||||
if md == nil {
|
||||
err.Metadata = nil
|
||||
return err
|
||||
}
|
||||
err.Metadata = make(map[string]string, len(md))
|
||||
for k, v := range md {
|
||||
err.Metadata[k] = v
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// New returns an error object for the code, message.
|
||||
func New(code int, reason, message string) *ApplicationError {
|
||||
return &ApplicationError{
|
||||
Status: Status{
|
||||
Code: int32(code),
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Newf New(code fmt.Sprintf(format, a...))
|
||||
func Newf(code int, reason, format string, a ...any) *ApplicationError {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Errorf returns an error object for the code, message and error info.
|
||||
func Errorf(code int, reason, format string, a ...any) error {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Code returns the http code for an error.
|
||||
// It supports wrapped errors.
|
||||
func Code(err error) int {
|
||||
if err == nil {
|
||||
return http.StatusOK
|
||||
}
|
||||
return int(FromError(err).Code)
|
||||
}
|
||||
|
||||
// Reason returns the reason for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Reason(err error) string {
|
||||
if err == nil {
|
||||
return UnknownReason
|
||||
}
|
||||
return FromError(err).Reason
|
||||
}
|
||||
|
||||
// Message returns the message for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Message(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return FromError(err).Message
|
||||
}
|
||||
|
||||
// Clone deep clone error to a new error.
|
||||
func Clone(err *ApplicationError) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var metadata map[string]string
|
||||
if err.Metadata != nil {
|
||||
metadata = make(map[string]string, len(err.Metadata))
|
||||
for k, v := range err.Metadata {
|
||||
metadata[k] = v
|
||||
}
|
||||
}
|
||||
return &ApplicationError{
|
||||
cause: err.cause,
|
||||
Status: Status{
|
||||
Code: err.Code,
|
||||
Reason: err.Reason,
|
||||
Message: err.Message,
|
||||
Metadata: metadata,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// FromError tries to convert an error to *ApplicationError.
|
||||
// It supports wrapped errors.
|
||||
func FromError(err error) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se
|
||||
}
|
||||
|
||||
// Fall back to a generic internal error.
|
||||
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
|
||||
}
|
||||
168
backend/internal/infrastructure/errors/errors_test.go
Normal file
168
backend/internal/infrastructure/errors/errors_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
//go:build unit
|
||||
|
||||
package errors
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplicationError_Basics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *ApplicationError
|
||||
want Status
|
||||
wantIs bool
|
||||
target error
|
||||
wrapped error
|
||||
}{
|
||||
{
|
||||
name: "new",
|
||||
err: New(400, "BAD_REQUEST", "invalid input"),
|
||||
want: Status{
|
||||
Code: 400,
|
||||
Reason: "BAD_REQUEST",
|
||||
Message: "invalid input",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "is_matches_code_and_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "UNAUTHORIZED", "ignored message"),
|
||||
wantIs: true,
|
||||
},
|
||||
{
|
||||
name: "is_does_not_match_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "DIFFERENT", "ignored message"),
|
||||
wantIs: false,
|
||||
},
|
||||
{
|
||||
name: "from_error_unwraps_wrapped_application_error",
|
||||
err: New(404, "NOT_FOUND", "missing"),
|
||||
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
|
||||
want: Status{
|
||||
Code: 404,
|
||||
Reason: "NOT_FOUND",
|
||||
Message: "missing",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err != nil {
|
||||
require.Equal(t, tt.want, tt.err.Status)
|
||||
}
|
||||
|
||||
if tt.target != nil {
|
||||
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
|
||||
}
|
||||
|
||||
if tt.wrapped != nil {
|
||||
got := FromError(tt.wrapped)
|
||||
require.Equal(t, tt.want, got.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
md map[string]string
|
||||
}{
|
||||
{name: "non_nil", md: map[string]string{"a": "1"}},
|
||||
{name: "nil", md: nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
|
||||
|
||||
if tt.md == nil {
|
||||
require.Nil(t, appErr.Metadata)
|
||||
return
|
||||
}
|
||||
|
||||
tt.md["a"] = "changed"
|
||||
require.Equal(t, "1", appErr.Metadata["a"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromError_Generic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantCode int32
|
||||
wantReason string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
err: stderrors.New("boom"),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
{
|
||||
name: "wrapped_plain_error",
|
||||
err: fmt.Errorf("wrap: %w", io.EOF),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FromError(tt.err)
|
||||
require.Equal(t, tt.wantCode, got.Code)
|
||||
require.Equal(t, tt.wantReason, got.Reason)
|
||||
require.Equal(t, tt.wantMsg, got.Message)
|
||||
require.Equal(t, tt.err, got.Unwrap())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatusCode int
|
||||
wantBody Status
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: Status{Code: int32(http.StatusOK)},
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: Forbidden("FORBIDDEN", "no access"),
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: Status{
|
||||
Code: int32(http.StatusForbidden),
|
||||
Reason: "FORBIDDEN",
|
||||
Message: "no access",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, body := ToHTTP(tt.err)
|
||||
require.Equal(t, tt.wantStatusCode, code)
|
||||
require.Equal(t, tt.wantBody, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
21
backend/internal/infrastructure/errors/http.go
Normal file
21
backend/internal/infrastructure/errors/http.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
|
||||
//
|
||||
// The returned body matches the project's Status shape:
|
||||
// { code, reason, message, metadata }.
|
||||
func ToHTTP(err error) (statusCode int, body Status) {
|
||||
if err == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
appErr := FromError(err)
|
||||
if appErr == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
cloned := Clone(appErr)
|
||||
return int(cloned.Code), cloned.Status
|
||||
}
|
||||
114
backend/internal/infrastructure/errors/types.go
Normal file
114
backend/internal/infrastructure/errors/types.go
Normal file
@@ -0,0 +1,114 @@
|
||||
// nolint:mnd
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// BadRequest new BadRequest error that is mapped to a 400 response.
|
||||
func BadRequest(reason, message string) *ApplicationError {
|
||||
return New(http.StatusBadRequest, reason, message)
|
||||
}
|
||||
|
||||
// IsBadRequest determines if err is an error which indicates a BadRequest error.
|
||||
// It supports wrapped errors.
|
||||
func IsBadRequest(err error) bool {
|
||||
return Code(err) == http.StatusBadRequest
|
||||
}
|
||||
|
||||
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
|
||||
func TooManyRequests(reason, message string) *ApplicationError {
|
||||
return New(http.StatusTooManyRequests, reason, message)
|
||||
}
|
||||
|
||||
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
|
||||
// It supports wrapped errors.
|
||||
func IsTooManyRequests(err error) bool {
|
||||
return Code(err) == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// Unauthorized new Unauthorized error that is mapped to a 401 response.
|
||||
func Unauthorized(reason, message string) *ApplicationError {
|
||||
return New(http.StatusUnauthorized, reason, message)
|
||||
}
|
||||
|
||||
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
|
||||
// It supports wrapped errors.
|
||||
func IsUnauthorized(err error) bool {
|
||||
return Code(err) == http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// Forbidden new Forbidden error that is mapped to a 403 response.
|
||||
func Forbidden(reason, message string) *ApplicationError {
|
||||
return New(http.StatusForbidden, reason, message)
|
||||
}
|
||||
|
||||
// IsForbidden determines if err is an error which indicates a Forbidden error.
|
||||
// It supports wrapped errors.
|
||||
func IsForbidden(err error) bool {
|
||||
return Code(err) == http.StatusForbidden
|
||||
}
|
||||
|
||||
// NotFound new NotFound error that is mapped to a 404 response.
|
||||
func NotFound(reason, message string) *ApplicationError {
|
||||
return New(http.StatusNotFound, reason, message)
|
||||
}
|
||||
|
||||
// IsNotFound determines if err is an error which indicates an NotFound error.
|
||||
// It supports wrapped errors.
|
||||
func IsNotFound(err error) bool {
|
||||
return Code(err) == http.StatusNotFound
|
||||
}
|
||||
|
||||
// Conflict new Conflict error that is mapped to a 409 response.
|
||||
func Conflict(reason, message string) *ApplicationError {
|
||||
return New(http.StatusConflict, reason, message)
|
||||
}
|
||||
|
||||
// IsConflict determines if err is an error which indicates a Conflict error.
|
||||
// It supports wrapped errors.
|
||||
func IsConflict(err error) bool {
|
||||
return Code(err) == http.StatusConflict
|
||||
}
|
||||
|
||||
// InternalServer new InternalServer error that is mapped to a 500 response.
|
||||
func InternalServer(reason, message string) *ApplicationError {
|
||||
return New(http.StatusInternalServerError, reason, message)
|
||||
}
|
||||
|
||||
// IsInternalServer determines if err is an error which indicates an Internal error.
|
||||
// It supports wrapped errors.
|
||||
func IsInternalServer(err error) bool {
|
||||
return Code(err) == http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
|
||||
func ServiceUnavailable(reason, message string) *ApplicationError {
|
||||
return New(http.StatusServiceUnavailable, reason, message)
|
||||
}
|
||||
|
||||
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
|
||||
// It supports wrapped errors.
|
||||
func IsServiceUnavailable(err error) bool {
|
||||
return Code(err) == http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
|
||||
func GatewayTimeout(reason, message string) *ApplicationError {
|
||||
return New(http.StatusGatewayTimeout, reason, message)
|
||||
}
|
||||
|
||||
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
|
||||
// It supports wrapped errors.
|
||||
func IsGatewayTimeout(err error) bool {
|
||||
return Code(err) == http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
|
||||
func ClientClosed(reason, message string) *ApplicationError {
|
||||
return New(499, reason, message)
|
||||
}
|
||||
|
||||
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
|
||||
// It supports wrapped errors.
|
||||
func IsClientClosed(err error) bool {
|
||||
return Code(err) == 499
|
||||
}
|
||||
@@ -4,14 +4,17 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 标准API响应格式
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedData 分页数据格式(匹配前端期望)
|
||||
@@ -44,11 +47,36 @@ func Created(c *gin.Context, data any) {
|
||||
// Error 返回错误响应
|
||||
func Error(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: "",
|
||||
Metadata: nil,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorWithDetails returns an error response compatible with the existing envelope while
|
||||
// optionally providing structured error fields (reason/metadata).
|
||||
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
Metadata: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
|
||||
// It returns true if an error was written.
|
||||
func ErrorFrom(c *gin.Context, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
return true
|
||||
}
|
||||
|
||||
// BadRequest 返回400错误
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
Error(c, http.StatusBadRequest, message)
|
||||
|
||||
171
backend/internal/pkg/response/response_test.go
Normal file
171
backend/internal/pkg/response/response_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
//go:build unit
|
||||
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestErrorWithDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
reason string
|
||||
metadata map[string]string
|
||||
want Response
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "invalid request",
|
||||
want: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "structured_error",
|
||||
statusCode: http.StatusForbidden,
|
||||
message: "no access",
|
||||
reason: "FORBIDDEN",
|
||||
metadata: map[string]string{"k": "v"},
|
||||
want: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"k": "v"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorFrom(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantWritten bool
|
||||
wantHTTPCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantWritten: false,
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusForbidden,
|
||||
wantBody: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"scope": "admin"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad_request_error",
|
||||
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusBadRequest,
|
||||
wantBody: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
Reason: "INVALID_REQUEST",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthorized_error",
|
||||
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusUnauthorized,
|
||||
wantBody: Response{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "unauthorized",
|
||||
Reason: "UNAUTHORIZED",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not_found_error",
|
||||
err: infraerrors.NotFound("NOT_FOUND", "not found"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusNotFound,
|
||||
wantBody: Response{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "not found",
|
||||
Reason: "NOT_FOUND",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "conflict_error",
|
||||
err: infraerrors.Conflict("CONFLICT", "conflict"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusConflict,
|
||||
wantBody: Response{
|
||||
Code: http.StatusConflict,
|
||||
Message: "conflict",
|
||||
Reason: "CONFLICT",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown_error_defaults_to_500",
|
||||
err: errors.New("boom"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: infraerrors.UnknownMessage,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
written := ErrorFrom(c, tt.err)
|
||||
require.Equal(t, tt.wantWritten, written)
|
||||
|
||||
if !tt.wantWritten {
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Empty(t, w.Body.String())
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,32 +3,33 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"time"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type AccountRepository struct {
|
||||
type accountRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAccountRepository(db *gorm.DB) *AccountRepository {
|
||||
return &AccountRepository{db: db}
|
||||
func NewAccountRepository(db *gorm.DB) service.AccountRepository {
|
||||
return &accountRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
|
||||
func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Create(account).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||
}
|
||||
// 填充 GroupIDs 和 Groups 虚拟字段
|
||||
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
|
||||
@@ -42,7 +43,7 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
|
||||
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
|
||||
if crsAccountID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -58,11 +59,11 @@ func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
|
||||
func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Save(account).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 先删除账号与分组的绑定关系
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return err
|
||||
@@ -71,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
|
||||
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
var accounts []model.Account
|
||||
var total int64
|
||||
|
||||
@@ -130,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
@@ -141,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", model.StatusActive).
|
||||
@@ -151,12 +152,12 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"status": model.StatusError,
|
||||
@@ -164,7 +165,7 @@ func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
ag := &model.AccountGroup{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
@@ -173,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
||||
return r.db.WithContext(ctx).Create(ag).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
|
||||
Delete(&model.AccountGroup{}).Error
|
||||
}
|
||||
|
||||
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
|
||||
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
|
||||
@@ -187,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m
|
||||
return groups, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ? AND status = ?", platform, model.StatusActive).
|
||||
@@ -197,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string)
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
// 删除现有绑定
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return err
|
||||
@@ -220,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
||||
}
|
||||
|
||||
// ListSchedulable 获取所有可调度的账号
|
||||
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
@@ -234,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupID 按组获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
@@ -250,7 +251,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
// ListSchedulableByPlatform 按平台获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
@@ -265,7 +266,7 @@ func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platf
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
|
||||
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
|
||||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
@@ -282,7 +283,7 @@ func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
|
||||
}
|
||||
|
||||
// SetRateLimited 标记账号为限流状态(429)
|
||||
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -292,13 +293,13 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
}
|
||||
|
||||
// SetOverloaded 标记账号为过载状态(529)
|
||||
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("overload_until", until).Error
|
||||
}
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": nil,
|
||||
@@ -308,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
||||
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
updates := map[string]any{
|
||||
"session_window_status": status,
|
||||
}
|
||||
@@ -322,14 +323,14 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
}
|
||||
|
||||
// SetSchedulable 设置账号的调度开关
|
||||
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
Update("schedulable", schedulable).Error
|
||||
}
|
||||
|
||||
// UpdateExtra updates specific fields in account's Extra JSONB field
|
||||
// It merges the updates into existing Extra data without overwriting other fields
|
||||
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -357,7 +358,7 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
|
||||
// BulkUpdate updates multiple accounts with the provided fields.
|
||||
// It merges credentials/extra JSONB fields instead of overwriting them.
|
||||
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ports.AccountBulkUpdate) (int64, error) {
|
||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -18,13 +18,13 @@ type AccountRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *AccountRepository
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewAccountRepository(s.db)
|
||||
s.repo = NewAccountRepository(s.db).(*accountRepository)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
@@ -167,7 +167,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
db := testTx(s.T())
|
||||
repo := NewAccountRepository(db)
|
||||
repo := NewAccountRepository(db).(*accountRepository)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(db)
|
||||
@@ -513,7 +513,7 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
|
||||
|
||||
newPriority := 99
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, ports.AccountBulkUpdate{
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
|
||||
Priority: &newPriority,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
@@ -531,7 +531,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
Credentials: model.JSONB{"existing": "value"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Credentials: model.JSONB{"new_key": "new_value"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
@@ -547,7 +547,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||
Extra: model.JSONB{"existing": "val"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Extra: model.JSONB{"new_key": "new_val"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
@@ -558,7 +558,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, ports.AccountBulkUpdate{})
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
@@ -566,7 +566,7 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
|
||||
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{})
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -19,7 +18,7 @@ type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewApiKeyCache(rdb *redis.Client) ports.ApiKeyCache {
|
||||
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,51 +2,55 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyRepository struct {
|
||||
type apiKeyRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository {
|
||||
return &ApiKeyRepository{db: db}
|
||||
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
|
||||
return r.db.WithContext(ctx).Create(key).Error
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
|
||||
err := r.db.WithContext(ctx).Create(key).Error
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
var key model.ApiKey
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
|
||||
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
@@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
var total int64
|
||||
|
||||
@@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||
var keys []model.ApiKey
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
|
||||
@@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Update("group_id", nil)
|
||||
@@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
|
||||
@@ -16,13 +16,13 @@ type ApiKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *ApiKeyRepository
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewApiKeyRepository(s.db)
|
||||
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
|
||||
@@ -8,8 +8,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -58,7 +57,7 @@ type billingCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewBillingCache(rdb *redis.Client) ports.BillingCache {
|
||||
func NewBillingCache(rdb *redis.Client) service.BillingCache {
|
||||
return &billingCache{rdb: rdb}
|
||||
}
|
||||
|
||||
@@ -90,7 +89,7 @@ func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*ports.SubscriptionCacheData, error) {
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
@@ -102,8 +101,8 @@ func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.SubscriptionCacheData, error) {
|
||||
result := &ports.SubscriptionCacheData{}
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
|
||||
result := &service.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
@@ -136,7 +135,7 @@ func (c *billingCache) parseSubscriptionCache(data map[string]string) (*ports.Su
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *ports.SubscriptionCacheData) error {
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -21,18 +21,18 @@ type BillingCacheSuite struct {
|
||||
func (s *BillingCacheSuite) TestUserBalance() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
_, err := cache.GetUserBalance(ctx, 1)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(1)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
@@ -44,7 +44,7 @@ func (s *BillingCacheSuite) TestUserBalance() {
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(2)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
@@ -61,7 +61,7 @@ func (s *BillingCacheSuite) TestUserBalance() {
|
||||
},
|
||||
{
|
||||
name: "deduct_reduces_balance",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(3)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
@@ -74,7 +74,7 @@ func (s *BillingCacheSuite) TestUserBalance() {
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(100)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
@@ -96,7 +96,7 @@ func (s *BillingCacheSuite) TestUserBalance() {
|
||||
},
|
||||
{
|
||||
name: "deduct_refreshes_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(103)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
@@ -133,11 +133,11 @@ func (s *BillingCacheSuite) TestUserBalance() {
|
||||
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(10)
|
||||
groupID := int64(20)
|
||||
|
||||
@@ -147,7 +147,7 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
},
|
||||
{
|
||||
name: "update_usage_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(11)
|
||||
groupID := int64(21)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
@@ -161,12 +161,12 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(12)
|
||||
groupID := int64(22)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &ports.SubscriptionCacheData{
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
@@ -189,11 +189,11 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
},
|
||||
{
|
||||
name: "update_usage_increments_all_fields",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(13)
|
||||
groupID := int64(23)
|
||||
|
||||
data := &ports.SubscriptionCacheData{
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
@@ -214,12 +214,12 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(101)
|
||||
groupID := int64(10)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &ports.SubscriptionCacheData{
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
@@ -245,7 +245,7 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
},
|
||||
{
|
||||
name: "missing_status_returns_parsing_error",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(102)
|
||||
groupID := int64(11)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -107,7 +106,7 @@ type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewConcurrencyCache(rdb *redis.Client) ports.ConcurrencyCache {
|
||||
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
|
||||
return &concurrencyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache ports.ConcurrencyCache
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -16,24 +15,24 @@ type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewEmailCache(rdb *redis.Client) ports.EmailCache {
|
||||
func NewEmailCache(rdb *redis.Client) service.EmailCache {
|
||||
return &emailCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*ports.VerificationCodeData, error) {
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data ports.VerificationCodeData
|
||||
var data service.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *ports.VerificationCodeData, ttl time.Duration) error {
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
type EmailCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache ports.EmailCache
|
||||
cache service.EmailCache
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) SetupTest() {
|
||||
@@ -31,7 +31,7 @@ func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
|
||||
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||
email := "a@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &ports.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
@@ -44,7 +44,7 @@ func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||
email := "ttl@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &ports.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
@@ -56,7 +56,7 @@ func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
|
||||
email := "delete@example.com"
|
||||
data := &ports.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
|
||||
|
||||
|
||||
40
backend/internal/repository/error_translate.go
Normal file
40
backend/internal/repository/error_translate.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return notFound.WithCause(err)
|
||||
}
|
||||
|
||||
if conflict != nil && isUniqueConstraintViolation(err) {
|
||||
return conflict.WithCause(err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func isUniqueConstraintViolation(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return true
|
||||
}
|
||||
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "duplicate key") ||
|
||||
strings.Contains(msg, "unique constraint") ||
|
||||
strings.Contains(msg, "duplicate entry")
|
||||
}
|
||||
@@ -4,8 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -15,7 +14,7 @@ type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGatewayCache(rdb *redis.Client) ports.GatewayCache {
|
||||
func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
type GatewayCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache ports.GatewayCache
|
||||
cache service.GatewayCache
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) SetupTest() {
|
||||
|
||||
@@ -2,47 +2,52 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type GroupRepository struct {
|
||||
type groupRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewGroupRepository(db *gorm.DB) *GroupRepository {
|
||||
return &GroupRepository{db: db}
|
||||
func NewGroupRepository(db *gorm.DB) service.GroupRepository {
|
||||
return &groupRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error {
|
||||
return r.db.WithContext(ctx).Create(group).Error
|
||||
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
|
||||
err := r.db.WithContext(ctx).Create(group).Error
|
||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||
}
|
||||
|
||||
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
var group model.Group
|
||||
err := r.db.WithContext(ctx).First(&group, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error {
|
||||
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
|
||||
return r.db.WithContext(ctx).Save(group).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
|
||||
}
|
||||
|
||||
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
}
|
||||
|
||||
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
||||
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
var groups []model.Group
|
||||
var total int64
|
||||
|
||||
@@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
|
||||
if err != nil {
|
||||
@@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error)
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
|
||||
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
|
||||
if err != nil {
|
||||
@@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
|
||||
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DB 返回底层数据库连接,用于事务处理
|
||||
func (r *GroupRepository) DB() *gorm.DB {
|
||||
return r.db
|
||||
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
group, err := r.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var affectedUserIDs []int64
|
||||
if group.IsSubscriptionType() {
|
||||
var subscriptions []model.UserSubscription
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&model.UserSubscription{}).
|
||||
Where("group_id = ?", id).
|
||||
Select("user_id").
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, sub := range subscriptions {
|
||||
affectedUserIDs = append(affectedUserIDs, sub.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 删除订阅类型分组的订阅记录
|
||||
if group.IsSubscriptionType() {
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
|
||||
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3. 从 users.allowed_groups 数组中移除该分组 ID
|
||||
if err := tx.Model(&model.User{}).
|
||||
Where("? = ANY(allowed_groups)", id).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. 删除 account_groups 中间表的数据
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 5. 删除分组本身(带锁,避免并发写)
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return affectedUserIDs, nil
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ type GroupRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *GroupRepository
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewGroupRepository(s.db)
|
||||
s.repo = NewGroupRepository(s.db).(*groupRepository)
|
||||
}
|
||||
|
||||
func TestGroupRepoSuite(t *testing.T) {
|
||||
@@ -234,11 +234,3 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- DB ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDB() {
|
||||
db := s.repo.DB()
|
||||
s.Require().NotNil(db, "DB should return non-nil")
|
||||
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// httpUpstreamService is a generic HTTP upstream service that can be used for
|
||||
@@ -17,7 +17,7 @@ type httpUpstreamService struct {
|
||||
}
|
||||
|
||||
// NewHTTPUpstream creates a new generic HTTP upstream service
|
||||
func NewHTTPUpstream(cfg *config.Config) ports.HTTPUpstream {
|
||||
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
|
||||
@@ -6,8 +6,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -20,24 +19,24 @@ type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewIdentityCache(rdb *redis.Client) ports.IdentityCache {
|
||||
func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
|
||||
return &identityCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*ports.Fingerprint, error) {
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fp ports.Fingerprint
|
||||
var fp service.Fingerprint
|
||||
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fp, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *ports.Fingerprint) error {
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
val, err := json.Marshal(fp)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -30,7 +30,7 @@ func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
|
||||
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
|
||||
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.NoError(s.T(), err, "GetFingerprint")
|
||||
@@ -39,7 +39,7 @@ func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
|
||||
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
|
||||
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
|
||||
|
||||
@@ -7,13 +7,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
||||
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
|
||||
func NewOpenAIOAuthClient() service.OpenAIOAuthClient {
|
||||
return &openaiOAuthService{tokenURL: openai.TokenURL}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,47 +2,50 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ProxyRepository struct {
|
||||
type proxyRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewProxyRepository(db *gorm.DB) *ProxyRepository {
|
||||
return &ProxyRepository{db: db}
|
||||
func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
|
||||
return &proxyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
|
||||
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
|
||||
return r.db.WithContext(ctx).Create(proxy).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
var proxy model.Proxy
|
||||
err := r.db.WithContext(ctx).First(&proxy, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
|
||||
}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
|
||||
return r.db.WithContext(ctx).Save(proxy).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
var proxies []model.Proxy
|
||||
var total int64
|
||||
|
||||
@@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
var proxies []model.Proxy
|
||||
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
|
||||
return proxies, err
|
||||
}
|
||||
|
||||
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
|
||||
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
|
||||
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
|
||||
@@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
|
||||
}
|
||||
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Account{}).
|
||||
Where("proxy_id = ?", proxyID).
|
||||
@@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
|
||||
}
|
||||
|
||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
||||
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
|
||||
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
|
||||
type result struct {
|
||||
ProxyID int64 `gorm:"column:proxy_id"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
@@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
|
||||
}
|
||||
|
||||
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
|
||||
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
|
||||
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
|
||||
var proxies []model.Proxy
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", model.StatusActive).
|
||||
|
||||
@@ -17,13 +17,13 @@ type ProxyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *ProxyRepository
|
||||
repo *proxyRepository
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewProxyRepository(s.db)
|
||||
s.repo = NewProxyRepository(s.db).(*proxyRepository)
|
||||
}
|
||||
|
||||
func TestProxyRepoSuite(t *testing.T) {
|
||||
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -20,7 +19,7 @@ type redeemCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewRedeemCache(rdb *redis.Client) ports.RedeemCache {
|
||||
func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
|
||||
return &redeemCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,57 +2,60 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type RedeemCodeRepository struct {
|
||||
type redeemCodeRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository {
|
||||
return &RedeemCodeRepository{db: db}
|
||||
func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
|
||||
return &redeemCodeRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
|
||||
func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Create(code).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
|
||||
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Create(&codes).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
var code model.RedeemCode
|
||||
err := r.db.WithContext(ctx).First(&code, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
|
||||
}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
|
||||
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
|
||||
var redeemCode model.RedeemCode
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
|
||||
}
|
||||
return &redeemCode, nil
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
|
||||
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
var codes []model.RedeemCode
|
||||
var total int64
|
||||
|
||||
@@ -91,11 +94,11 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
|
||||
func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Save(code).Error
|
||||
}
|
||||
|
||||
func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
|
||||
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
|
||||
now := time.Now()
|
||||
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
||||
Where("id = ? AND status = ?", id, model.StatusUnused).
|
||||
@@ -108,13 +111,13 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound // 兑换码不存在或已被使用
|
||||
return service.ErrRedeemCodeUsed.WithCause(gorm.ErrRecordNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByUser returns all redeem codes used by a specific user
|
||||
func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
|
||||
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
|
||||
var codes []model.RedeemCode
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -17,13 +18,13 @@ type RedeemCodeRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *RedeemCodeRepository
|
||||
repo *redeemCodeRepository
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewRedeemCodeRepository(s.db)
|
||||
s.repo = NewRedeemCodeRepository(s.db).(*redeemCodeRepository)
|
||||
}
|
||||
|
||||
func TestRedeemCodeRepoSuite(t *testing.T) {
|
||||
@@ -195,7 +196,7 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
|
||||
// Second use should fail
|
||||
err = s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().Error(err, "Use expected error on second call")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||
@@ -204,7 +205,7 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().Error(err, "expected error for already used code")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
|
||||
}
|
||||
|
||||
// --- ListByUser ---
|
||||
@@ -298,7 +299,7 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
|
||||
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
|
||||
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
|
||||
s.Require().Error(err, "Use expected error on second call")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
|
||||
|
||||
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
package repository
|
||||
|
||||
// Repositories 所有仓库的集合
|
||||
type Repositories struct {
|
||||
User *UserRepository
|
||||
ApiKey *ApiKeyRepository
|
||||
Group *GroupRepository
|
||||
Account *AccountRepository
|
||||
Proxy *ProxyRepository
|
||||
RedeemCode *RedeemCodeRepository
|
||||
UsageLog *UsageLogRepository
|
||||
Setting *SettingRepository
|
||||
UserSubscription *UserSubscriptionRepository
|
||||
}
|
||||
@@ -2,35 +2,38 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// SettingRepository 系统设置数据访问层
|
||||
type SettingRepository struct {
|
||||
type settingRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSettingRepository 创建系统设置仓库实例
|
||||
func NewSettingRepository(db *gorm.DB) *SettingRepository {
|
||||
return &SettingRepository{db: db}
|
||||
func NewSettingRepository(db *gorm.DB) service.SettingRepository {
|
||||
return &settingRepository{db: db}
|
||||
}
|
||||
|
||||
// Get 根据Key获取设置值
|
||||
func (r *SettingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
|
||||
func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
|
||||
var setting model.Setting
|
||||
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
|
||||
}
|
||||
return &setting, nil
|
||||
}
|
||||
|
||||
// GetValue 获取设置值字符串
|
||||
func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, error) {
|
||||
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
|
||||
setting, err := r.Get(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -39,7 +42,7 @@ func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, e
|
||||
}
|
||||
|
||||
// Set 设置值(存在则更新,不存在则创建)
|
||||
func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
|
||||
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
|
||||
setting := &model.Setting{
|
||||
Key: key,
|
||||
Value: value,
|
||||
@@ -53,7 +56,7 @@ func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
|
||||
}
|
||||
|
||||
// GetMultiple 批量获取设置
|
||||
func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
var settings []model.Setting
|
||||
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
|
||||
if err != nil {
|
||||
@@ -68,7 +71,7 @@ func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map
|
||||
}
|
||||
|
||||
// SetMultiple 批量设置值
|
||||
func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for key, value := range settings {
|
||||
setting := &model.Setting{
|
||||
@@ -88,7 +91,7 @@ func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string
|
||||
}
|
||||
|
||||
// GetAll 获取所有设置
|
||||
func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
var settings []model.Setting
|
||||
err := r.db.WithContext(ctx).Find(&settings).Error
|
||||
if err != nil {
|
||||
@@ -103,6 +106,6 @@ func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, erro
|
||||
}
|
||||
|
||||
// Delete 删除设置
|
||||
func (r *SettingRepository) Delete(ctx context.Context, key string) error {
|
||||
func (r *settingRepository) Delete(ctx context.Context, key string) error {
|
||||
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -14,13 +15,13 @@ type SettingRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *SettingRepository
|
||||
repo *settingRepository
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewSettingRepository(s.db)
|
||||
s.repo = NewSettingRepository(s.db).(*settingRepository)
|
||||
}
|
||||
|
||||
func TestSettingRepoSuite(t *testing.T) {
|
||||
@@ -45,7 +46,7 @@ func (s *SettingRepoSuite) TestSet_Upsert() {
|
||||
func (s *SettingRepoSuite) TestGetValue_Missing() {
|
||||
_, err := s.repo.GetValue(s.ctx, "nonexistent")
|
||||
s.Require().Error(err, "expected error for missing key")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
s.Require().ErrorIs(err, service.ErrSettingNotFound)
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
|
||||
@@ -86,7 +87,7 @@ func (s *SettingRepoSuite) TestDelete() {
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
|
||||
_, err := s.repo.GetValue(s.ctx, "todelete")
|
||||
s.Require().Error(err, "expected missing key error after Delete")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
s.Require().ErrorIs(err, service.ErrSettingNotFound)
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestDelete_Idempotent() {
|
||||
|
||||
@@ -4,8 +4,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -15,7 +14,7 @@ type updateCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewUpdateCache(rdb *redis.Client) ports.UpdateCache {
|
||||
func NewUpdateCache(rdb *redis.Client) service.UpdateCache {
|
||||
return &updateCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,25 +2,28 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UsageLogRepository struct {
|
||||
type usageLogRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository {
|
||||
return &UsageLogRepository{db: db}
|
||||
func NewUsageLogRepository(db *gorm.DB) service.UsageLogRepository {
|
||||
return &usageLogRepository{db: db}
|
||||
}
|
||||
|
||||
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
|
||||
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
|
||||
func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
|
||||
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
|
||||
var perfStats struct {
|
||||
RequestCount int64 `gorm:"column:request_count"`
|
||||
@@ -43,20 +46,20 @@ func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
||||
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
|
||||
var log model.UsageLog
|
||||
err := r.db.WithContext(ctx).First(&log, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
|
||||
}
|
||||
return &log, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -83,7 +86,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -120,7 +123,7 @@ type UserStats struct {
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
|
||||
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
|
||||
var stats UserStats
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Select(`
|
||||
@@ -139,7 +142,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
||||
// DashboardStats 仪表盘统计
|
||||
type DashboardStats = usagestats.DashboardStats
|
||||
|
||||
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
var stats DashboardStats
|
||||
today := timezone.Today()
|
||||
|
||||
@@ -260,7 +263,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -287,7 +290,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
|
||||
@@ -296,7 +299,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
|
||||
@@ -305,7 +308,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
@@ -314,7 +317,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
|
||||
@@ -323,12 +326,12 @@ func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelN
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
|
||||
}
|
||||
|
||||
// GetAccountTodayStats 获取账号今日统计
|
||||
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||
func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||
today := timezone.Today()
|
||||
|
||||
var stats struct {
|
||||
@@ -358,7 +361,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// GetAccountWindowStats 获取账号时间窗口内的统计
|
||||
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
var stats struct {
|
||||
Requests int64 `gorm:"column:requests"`
|
||||
Tokens int64 `gorm:"column:tokens"`
|
||||
@@ -398,7 +401,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||||
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
|
||||
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
|
||||
var results []ApiKeyUsageTrendPoint
|
||||
|
||||
// Choose date format based on granularity
|
||||
@@ -442,7 +445,7 @@ func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
||||
}
|
||||
|
||||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||||
func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
|
||||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
|
||||
var results []UserUsageTrendPoint
|
||||
|
||||
// Choose date format based on granularity
|
||||
@@ -491,7 +494,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
|
||||
type UserDashboardStats = usagestats.UserDashboardStats
|
||||
|
||||
// GetUserDashboardStats 获取用户专属的仪表盘统计
|
||||
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
|
||||
func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
|
||||
var stats UserDashboardStats
|
||||
today := timezone.Today()
|
||||
|
||||
@@ -578,7 +581,7 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
}
|
||||
|
||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||
func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
|
||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
|
||||
var results []TrendDataPoint
|
||||
|
||||
var dateFormat string
|
||||
@@ -612,7 +615,7 @@ func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
}
|
||||
|
||||
// GetUserModelStats 获取指定用户的模型统计
|
||||
func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
|
||||
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
@@ -641,7 +644,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
type UsageLogFilters = usagestats.UsageLogFilters
|
||||
|
||||
// ListWithFilters lists usage logs with optional filters (for admin)
|
||||
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
var total int64
|
||||
|
||||
@@ -692,7 +695,7 @@ type UsageStats = usagestats.UsageStats
|
||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
||||
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return make(map[int64]*BatchUserUsageStats), nil
|
||||
}
|
||||
@@ -752,7 +755,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
|
||||
|
||||
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
||||
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return make(map[int64]*BatchApiKeyUsageStats), nil
|
||||
}
|
||||
@@ -809,7 +812,7 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
|
||||
}
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
||||
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
|
||||
var results []TrendDataPoint
|
||||
|
||||
var dateFormat string
|
||||
@@ -848,7 +851,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
@@ -882,7 +885,7 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
}
|
||||
|
||||
// GetGlobalStats gets usage statistics for all users within a time range
|
||||
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
var stats struct {
|
||||
TotalRequests int64 `gorm:"column:total_requests"`
|
||||
TotalInputTokens int64 `gorm:"column:total_input_tokens"`
|
||||
@@ -932,7 +935,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
|
||||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||
|
||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
|
||||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
|
||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||
if daysCount <= 0 {
|
||||
daysCount = 30
|
||||
|
||||
@@ -19,13 +19,13 @@ type UsageLogRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *UsageLogRepository
|
||||
repo *usageLogRepository
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUsageLogRepository(s.db)
|
||||
s.repo = NewUsageLogRepository(s.db).(*usageLogRepository)
|
||||
}
|
||||
|
||||
func TestUsageLogRepoSuite(t *testing.T) {
|
||||
|
||||
@@ -2,56 +2,61 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserRepository struct {
|
||||
type userRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
func NewUserRepository(db *gorm.DB) service.UserRepository {
|
||||
return &userRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *UserRepository) Create(ctx context.Context, user *model.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
|
||||
err := r.db.WithContext(ctx).Create(user).Error
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).First(&user, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
|
||||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) Update(ctx context.Context, user *model.User) error {
|
||||
return r.db.WithContext(ctx).Save(user).Error
|
||||
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
|
||||
err := r.db.WithContext(ctx).Save(user).Error
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists users with optional filtering by status, role, and search query
|
||||
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
|
||||
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
@@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
|
||||
Update("balance", gorm.Expr("balance + ?", amount)).Error
|
||||
}
|
||||
|
||||
// DeductBalance 扣减用户余额,仅当余额充足时执行
|
||||
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
result := r.db.WithContext(ctx).Model(&model.User{}).
|
||||
Where("id = ? AND balance >= ?", id, amount).
|
||||
Update("balance", gorm.Expr("balance - ?", amount))
|
||||
@@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound // 余额不足或用户不存在
|
||||
return service.ErrInsufficientBalance
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
|
||||
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
|
||||
}
|
||||
|
||||
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
|
||||
return count > 0, err
|
||||
@@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
||||
|
||||
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
|
||||
// 使用 PostgreSQL 的 array_remove 函数
|
||||
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.User{}).
|
||||
Where("? = ANY(allowed_groups)", groupID).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
|
||||
@@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
|
||||
}
|
||||
|
||||
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
|
||||
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
|
||||
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
|
||||
Order("id ASC").
|
||||
First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
@@ -18,13 +19,13 @@ type UserRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *UserRepository
|
||||
repo *userRepository
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUserRepository(s.db)
|
||||
s.repo = NewUserRepository(s.db).(*userRepository)
|
||||
}
|
||||
|
||||
func TestUserRepoSuite(t *testing.T) {
|
||||
@@ -247,7 +248,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
|
||||
s.Require().Error(err, "expected error for insufficient balance")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
|
||||
@@ -432,7 +433,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
|
||||
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
|
||||
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
|
||||
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error")
|
||||
s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error")
|
||||
|
||||
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
|
||||
got5, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
|
||||
@@ -6,27 +6,29 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserSubscriptionRepository 用户订阅仓库
|
||||
type UserSubscriptionRepository struct {
|
||||
type userSubscriptionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserSubscriptionRepository 创建用户订阅仓库
|
||||
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository {
|
||||
return &UserSubscriptionRepository{db: db}
|
||||
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
|
||||
return &userSubscriptionRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建订阅
|
||||
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
|
||||
return r.db.WithContext(ctx).Create(sub).Error
|
||||
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
|
||||
err := r.db.WithContext(ctx).Create(sub).Error
|
||||
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("User").
|
||||
@@ -34,26 +36,26 @@ func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*mo
|
||||
Preload("AssignedByUser").
|
||||
First(&sub, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
|
||||
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
|
||||
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
@@ -61,24 +63,24 @@ func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
|
||||
userID, groupID, model.SubscriptionStatusActive, time.Now()).
|
||||
First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// Update 更新订阅
|
||||
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
|
||||
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
|
||||
sub.UpdatedAt = time.Now()
|
||||
return r.db.WithContext(ctx).Save(sub).Error
|
||||
}
|
||||
|
||||
// Delete 删除订阅
|
||||
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error {
|
||||
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
|
||||
}
|
||||
|
||||
// ListByUserID 获取用户的所有订阅
|
||||
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
@@ -89,7 +91,7 @@ func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID in
|
||||
}
|
||||
|
||||
// ListActiveByUserID 获取用户的所有有效订阅
|
||||
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
@@ -101,7 +103,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
|
||||
}
|
||||
|
||||
// ListByGroupID 获取分组的所有订阅(分页)
|
||||
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
var total int64
|
||||
|
||||
@@ -136,7 +138,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
var total int64
|
||||
|
||||
@@ -182,7 +184,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
}
|
||||
|
||||
// IncrementUsage 增加使用量
|
||||
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -194,7 +196,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
||||
}
|
||||
|
||||
// ResetDailyUsage 重置日使用量
|
||||
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -205,7 +207,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
|
||||
}
|
||||
|
||||
// ResetWeeklyUsage 重置周使用量
|
||||
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -216,7 +218,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
|
||||
}
|
||||
|
||||
// ResetMonthlyUsage 重置月使用量
|
||||
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -227,7 +229,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
|
||||
}
|
||||
|
||||
// ActivateWindows 激活所有窗口(首次使用时)
|
||||
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
||||
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -239,7 +241,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
|
||||
}
|
||||
|
||||
// UpdateStatus 更新订阅状态
|
||||
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
||||
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -249,7 +251,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
|
||||
}
|
||||
|
||||
// ExtendExpiry 延长订阅过期时间
|
||||
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
||||
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -259,7 +261,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
|
||||
}
|
||||
|
||||
// UpdateNotes 更新订阅备注
|
||||
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
||||
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
@@ -269,7 +271,7 @@ func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64,
|
||||
}
|
||||
|
||||
// ListExpired 获取所有已过期但状态仍为active的订阅
|
||||
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
|
||||
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||
@@ -278,7 +280,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
|
||||
}
|
||||
|
||||
// BatchUpdateExpiredStatus 批量更新过期订阅状态
|
||||
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||
Updates(map[string]any{
|
||||
@@ -289,7 +291,7 @@ func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
|
||||
}
|
||||
|
||||
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
|
||||
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
@@ -298,7 +300,7 @@ func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Contex
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的订阅数量
|
||||
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("group_id = ?", groupID).
|
||||
@@ -307,7 +309,7 @@ func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID
|
||||
}
|
||||
|
||||
// CountActiveByGroupID 获取分组的有效订阅数量
|
||||
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("group_id = ? AND status = ? AND expires_at > ?",
|
||||
@@ -317,7 +319,7 @@ func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除分组相关的所有订阅记录
|
||||
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@@ -17,13 +17,13 @@ type UserSubscriptionRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
db *gorm.DB
|
||||
repo *UserSubscriptionRepository
|
||||
repo *userSubscriptionRepository
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.db = testTx(s.T())
|
||||
s.repo = NewUserSubscriptionRepository(s.db)
|
||||
s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository)
|
||||
}
|
||||
|
||||
func TestUserSubscriptionRepoSuite(t *testing.T) {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
@@ -17,7 +15,6 @@ var ProviderSet = wire.NewSet(
|
||||
NewUsageLogRepository,
|
||||
NewSettingRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
wire.Struct(new(Repositories), "*"),
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
@@ -38,15 +35,4 @@ var ProviderSet = wire.NewSet(
|
||||
NewClaudeOAuthClient,
|
||||
NewHTTPUpstream,
|
||||
NewOpenAIOAuthClient,
|
||||
|
||||
// Bind concrete repositories to service port interfaces
|
||||
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
||||
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
|
||||
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
|
||||
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
|
||||
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
|
||||
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
|
||||
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
|
||||
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
|
||||
)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
@@ -19,15 +19,21 @@ var ProviderSet = wire.NewSet(
|
||||
)
|
||||
|
||||
// ProvideRouter 提供路由器
|
||||
func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||
func ProvideRouter(
|
||||
cfg *config.Config,
|
||||
handlers *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(middleware2.Recovery())
|
||||
|
||||
return SetupRouter(r, cfg, handlers, services, repos)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@@ -1,32 +1,39 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminAuth 管理员认证中间件
|
||||
// NewAdminAuthMiddleware 创建管理员认证中间件
|
||||
func NewAdminAuthMiddleware(
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
settingService *service.SettingService,
|
||||
) AdminAuthMiddleware {
|
||||
return AdminAuthMiddleware(adminAuth(authService, userService, settingService))
|
||||
}
|
||||
|
||||
// adminAuth 管理员认证中间件实现
|
||||
// 支持两种认证方式(通过不同的 header 区分):
|
||||
// 1. Admin API Key: x-api-key: <admin-api-key>
|
||||
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
|
||||
func AdminAuth(
|
||||
func adminAuth(
|
||||
authService *service.AuthService,
|
||||
userRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
},
|
||||
userService *service.UserService,
|
||||
settingService *service.SettingService,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userRepo) {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -38,7 +45,7 @@ func AdminAuth(
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
if !validateJWTForAdmin(c, parts[1], authService, userRepo) {
|
||||
if !validateJWTForAdmin(c, parts[1], authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -56,9 +63,7 @@ func validateAdminApiKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userRepo interface {
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
},
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
@@ -73,7 +78,7 @@ func validateAdminApiKey(
|
||||
}
|
||||
|
||||
// 获取真实的管理员用户
|
||||
admin, err := userRepo.GetFirstAdmin(c.Request.Context())
|
||||
admin, err := userService.GetFirstAdmin(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
|
||||
return false
|
||||
@@ -89,14 +94,12 @@ func validateJWTForAdmin(
|
||||
c *gin.Context,
|
||||
token string,
|
||||
authService *service.AuthService,
|
||||
userRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
},
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
// 验证 JWT token
|
||||
claims, err := authService.ValidateToken(token)
|
||||
if err != nil {
|
||||
if err == service.ErrTokenExpired {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return false
|
||||
}
|
||||
@@ -105,7 +108,7 @@ func validateJWTForAdmin(
|
||||
}
|
||||
|
||||
// 从数据库获取用户
|
||||
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
|
||||
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return false
|
||||
@@ -1,37 +1,24 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ApiKeyAuthService 定义API Key认证服务需要的接口
|
||||
type ApiKeyAuthService interface {
|
||||
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
|
||||
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) ApiKeyAuthMiddleware {
|
||||
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService))
|
||||
}
|
||||
|
||||
// SubscriptionAuthService 定义订阅认证服务需要的接口
|
||||
type SubscriptionAuthService interface {
|
||||
GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error
|
||||
CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error
|
||||
CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error
|
||||
CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error
|
||||
}
|
||||
|
||||
// ApiKeyAuth API Key认证中间件
|
||||
func ApiKeyAuth(apiKeyRepo ApiKeyAuthService) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscription(apiKeyRepo, nil)
|
||||
}
|
||||
|
||||
// ApiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func ApiKeyAuthWithSubscription(apiKeyRepo ApiKeyAuthService, subscriptionService SubscriptionAuthService) gin.HandlerFunc {
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
@@ -57,7 +44,7 @@ func ApiKeyAuthWithSubscription(apiKeyRepo ApiKeyAuthService, subscriptionServic
|
||||
}
|
||||
|
||||
// 从数据库验证API key
|
||||
apiKey, err := apiKeyRepo.GetByKey(c.Request.Context(), apiKeyString)
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
|
||||
@@ -1,18 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// JWTAuth JWT认证中间件
|
||||
func JWTAuth(authService *service.AuthService, userRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
}) gin.HandlerFunc {
|
||||
// NewJWTAuthMiddleware 创建 JWT 认证中间件
|
||||
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
|
||||
return JWTAuthMiddleware(jwtAuth(authService, userService))
|
||||
}
|
||||
|
||||
// jwtAuth JWT认证中间件实现
|
||||
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从Authorization header中提取token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
@@ -37,7 +41,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface {
|
||||
// 验证token
|
||||
claims, err := authService.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
if err == service.ErrTokenExpired {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return
|
||||
}
|
||||
@@ -46,7 +50,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface {
|
||||
}
|
||||
|
||||
// 从数据库获取最新的用户信息
|
||||
user, err := userRepo.GetByID(c.Request.Context(), claims.UserID)
|
||||
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return
|
||||
64
backend/internal/server/middleware/recovery.go
Normal file
64
backend/internal/server/middleware/recovery.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Recovery converts panics into the project's standard JSON error envelope.
|
||||
//
|
||||
// It preserves Gin's broken-pipe handling by not attempting to write a response
|
||||
// when the client connection is already gone.
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return gin.CustomRecoveryWithWriter(gin.DefaultErrorWriter, func(c *gin.Context, recovered any) {
|
||||
recoveredErr, _ := recovered.(error)
|
||||
|
||||
if isBrokenPipe(recoveredErr) {
|
||||
if recoveredErr != nil {
|
||||
_ = c.Error(recoveredErr)
|
||||
}
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if c.Writer.Written() {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorWithDetails(
|
||||
c,
|
||||
http.StatusInternalServerError,
|
||||
infraerrors.UnknownMessage,
|
||||
infraerrors.UnknownReason,
|
||||
nil,
|
||||
)
|
||||
c.Abort()
|
||||
})
|
||||
}
|
||||
|
||||
func isBrokenPipe(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var opErr *net.OpError
|
||||
if !errors.As(err, &opErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
var syscallErr *os.SyscallError
|
||||
if !errors.As(opErr.Err, &syscallErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
msg := strings.ToLower(syscallErr.Error())
|
||||
return strings.Contains(msg, "broken pipe") || strings.Contains(msg, "connection reset by peer")
|
||||
}
|
||||
81
backend/internal/server/middleware/recovery_test.go
Normal file
81
backend/internal/server/middleware/recovery_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler gin.HandlerFunc
|
||||
wantHTTPCode int
|
||||
wantBody response.Response
|
||||
}{
|
||||
{
|
||||
name: "panic_returns_standard_json_500",
|
||||
handler: func(c *gin.Context) {
|
||||
panic("boom")
|
||||
},
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: response.Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: infraerrors.UnknownMessage,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no_panic_passthrough",
|
||||
handler: func(c *gin.Context) {
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
},
|
||||
wantHTTPCode: http.StatusOK,
|
||||
wantBody: response.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "panic_after_write_does_not_override_body",
|
||||
handler: func(c *gin.Context) {
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
panic("boom")
|
||||
},
|
||||
wantHTTPCode: http.StatusOK,
|
||||
wantBody: response.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(Recovery())
|
||||
r.GET("/t", tt.handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
|
||||
var got response.Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
22
backend/internal/server/middleware/wire.go
Normal file
22
backend/internal/server/middleware/wire.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware JWT 认证中间件类型
|
||||
type JWTAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// AdminAuthMiddleware 管理员认证中间件类型
|
||||
type AdminAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ApiKeyAuthMiddleware API Key 认证中间件类型
|
||||
type ApiKeyAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ProviderSet 中间件层的依赖注入
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewJWTAuthMiddleware,
|
||||
NewAdminAuthMiddleware,
|
||||
NewApiKeyAuthMiddleware,
|
||||
)
|
||||
@@ -1,312 +1,54 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/routes"
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SetupRouter 配置路由器中间件和路由
|
||||
func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine {
|
||||
func SetupRouter(
|
||||
r *gin.Engine,
|
||||
handlers *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware.Logger())
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, services, repos)
|
||||
r.Use(middleware2.Logger())
|
||||
r.Use(middleware2.CORS())
|
||||
|
||||
// Serve embedded frontend if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// registerRoutes 注册所有 HTTP 路由
|
||||
func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) {
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Claude Code 遥测日志(忽略,直接返回200)
|
||||
r.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// Setup status endpoint (always returns needs_setup: false in normal mode)
|
||||
// This is used by the frontend to detect when the service has restarted after setup
|
||||
r.GET("/setup/status", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": gin.H{
|
||||
"needs_setup": false,
|
||||
"step": "completed",
|
||||
},
|
||||
})
|
||||
})
|
||||
func registerRoutes(
|
||||
r *gin.Engine,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
) {
|
||||
// 通用路由(健康检查、状态等)
|
||||
routes.RegisterCommonRoutes(r)
|
||||
|
||||
// API v1
|
||||
v1 := r.Group("/api/v1")
|
||||
{
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
settings := v1.Group("/settings")
|
||||
{
|
||||
settings.GET("/public", h.Setting.GetPublicSettings)
|
||||
}
|
||||
|
||||
// 需要认证的接口
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(middleware.JWTAuth(s.Auth, repos.User))
|
||||
{
|
||||
// 当前用户信息
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
{
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
keys := authenticated.Group("/keys")
|
||||
{
|
||||
keys.GET("", h.APIKey.List)
|
||||
keys.GET("/:id", h.APIKey.GetByID)
|
||||
keys.POST("", h.APIKey.Create)
|
||||
keys.PUT("/:id", h.APIKey.Update)
|
||||
keys.DELETE("/:id", h.APIKey.Delete)
|
||||
}
|
||||
|
||||
// 用户可用分组(非管理员接口)
|
||||
groups := authenticated.Group("/groups")
|
||||
{
|
||||
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
||||
}
|
||||
|
||||
// 使用记录
|
||||
usage := authenticated.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Usage.List)
|
||||
usage.GET("/:id", h.Usage.GetByID)
|
||||
usage.GET("/stats", h.Usage.Stats)
|
||||
// User dashboard endpoints
|
||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
redeem := authenticated.Group("/redeem")
|
||||
{
|
||||
redeem.POST("", h.Redeem.Redeem)
|
||||
redeem.GET("/history", h.Redeem.GetHistory)
|
||||
}
|
||||
|
||||
// 用户订阅
|
||||
subscriptions := authenticated.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Subscription.List)
|
||||
subscriptions.GET("/active", h.Subscription.GetActive)
|
||||
subscriptions.GET("/progress", h.Subscription.GetProgress)
|
||||
subscriptions.GET("/summary", h.Subscription.GetSummary)
|
||||
}
|
||||
}
|
||||
|
||||
// 管理员接口
|
||||
admin := v1.Group("/admin")
|
||||
admin.Use(middleware.AdminAuth(s.Auth, repos.User, s.Setting))
|
||||
{
|
||||
// 仪表盘
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
||||
}
|
||||
|
||||
// 用户管理
|
||||
users := admin.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.User.List)
|
||||
users.GET("/:id", h.Admin.User.GetByID)
|
||||
users.POST("", h.Admin.User.Create)
|
||||
users.PUT("/:id", h.Admin.User.Update)
|
||||
users.DELETE("/:id", h.Admin.User.Delete)
|
||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||
}
|
||||
|
||||
// 分组管理
|
||||
groups := admin.Group("/groups")
|
||||
{
|
||||
groups.GET("", h.Admin.Group.List)
|
||||
groups.GET("/all", h.Admin.Group.GetAll)
|
||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||
groups.POST("", h.Admin.Group.Create)
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
|
||||
// 账号管理
|
||||
accounts := admin.Group("/accounts")
|
||||
{
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
|
||||
// Claude OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
||||
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
||||
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
|
||||
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
|
||||
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
||||
}
|
||||
|
||||
// OpenAI OAuth routes
|
||||
openai := admin.Group("/openai")
|
||||
{
|
||||
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||
}
|
||||
|
||||
// 代理管理
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
proxies.GET("", h.Admin.Proxy.List)
|
||||
proxies.GET("/all", h.Admin.Proxy.GetAll)
|
||||
proxies.GET("/:id", h.Admin.Proxy.GetByID)
|
||||
proxies.POST("", h.Admin.Proxy.Create)
|
||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
||||
}
|
||||
|
||||
// 卡密管理
|
||||
codes := admin.Group("/redeem-codes")
|
||||
{
|
||||
codes.GET("", h.Admin.Redeem.List)
|
||||
codes.GET("/stats", h.Admin.Redeem.GetStats)
|
||||
codes.GET("/export", h.Admin.Redeem.Export)
|
||||
codes.GET("/:id", h.Admin.Redeem.GetByID)
|
||||
codes.POST("/generate", h.Admin.Redeem.Generate)
|
||||
codes.DELETE("/:id", h.Admin.Redeem.Delete)
|
||||
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
|
||||
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
|
||||
}
|
||||
|
||||
// 系统设置
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
// Admin API Key 管理
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
|
||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
|
||||
}
|
||||
|
||||
// 系统管理
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
system.GET("/version", h.Admin.System.GetVersion)
|
||||
system.GET("/check-updates", h.Admin.System.CheckUpdates)
|
||||
system.POST("/update", h.Admin.System.PerformUpdate)
|
||||
system.POST("/rollback", h.Admin.System.Rollback)
|
||||
system.POST("/restart", h.Admin.System.RestartService)
|
||||
}
|
||||
|
||||
// 订阅管理
|
||||
subscriptions := admin.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Admin.Subscription.List)
|
||||
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
|
||||
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
|
||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||
}
|
||||
|
||||
// 分组下的订阅列表
|
||||
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
|
||||
|
||||
// 用户下的订阅列表
|
||||
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
|
||||
|
||||
// 使用记录管理
|
||||
usage := admin.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Admin.Usage.List)
|
||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses)
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth)
|
||||
}
|
||||
|
||||
221
backend/internal/server/routes/admin.go
Normal file
221
backend/internal/server/routes/admin.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterAdminRoutes 注册管理员路由
|
||||
func RegisterAdminRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
adminAuth middleware.AdminAuthMiddleware,
|
||||
) {
|
||||
admin := v1.Group("/admin")
|
||||
admin.Use(gin.HandlerFunc(adminAuth))
|
||||
{
|
||||
// 仪表盘
|
||||
registerDashboardRoutes(admin, h)
|
||||
|
||||
// 用户管理
|
||||
registerUserManagementRoutes(admin, h)
|
||||
|
||||
// 分组管理
|
||||
registerGroupRoutes(admin, h)
|
||||
|
||||
// 账号管理
|
||||
registerAccountRoutes(admin, h)
|
||||
|
||||
// OpenAI OAuth
|
||||
registerOpenAIOAuthRoutes(admin, h)
|
||||
|
||||
// 代理管理
|
||||
registerProxyRoutes(admin, h)
|
||||
|
||||
// 卡密管理
|
||||
registerRedeemCodeRoutes(admin, h)
|
||||
|
||||
// 系统设置
|
||||
registerSettingsRoutes(admin, h)
|
||||
|
||||
// 系统管理
|
||||
registerSystemRoutes(admin, h)
|
||||
|
||||
// 订阅管理
|
||||
registerSubscriptionRoutes(admin, h)
|
||||
|
||||
// 使用记录管理
|
||||
registerUsageRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
|
||||
}
|
||||
}
|
||||
|
||||
func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
users := admin.Group("/users")
|
||||
{
|
||||
users.GET("", h.Admin.User.List)
|
||||
users.GET("/:id", h.Admin.User.GetByID)
|
||||
users.POST("", h.Admin.User.Create)
|
||||
users.PUT("/:id", h.Admin.User.Update)
|
||||
users.DELETE("/:id", h.Admin.User.Delete)
|
||||
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
|
||||
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
|
||||
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
|
||||
}
|
||||
}
|
||||
|
||||
func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
groups := admin.Group("/groups")
|
||||
{
|
||||
groups.GET("", h.Admin.Group.List)
|
||||
groups.GET("/all", h.Admin.Group.GetAll)
|
||||
groups.GET("/:id", h.Admin.Group.GetByID)
|
||||
groups.POST("", h.Admin.Group.Create)
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
}
|
||||
|
||||
func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts := admin.Group("/accounts")
|
||||
{
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
|
||||
// Claude OAuth routes
|
||||
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
|
||||
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
|
||||
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
|
||||
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
|
||||
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
|
||||
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
openai := admin.Group("/openai")
|
||||
{
|
||||
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
proxies := admin.Group("/proxies")
|
||||
{
|
||||
proxies.GET("", h.Admin.Proxy.List)
|
||||
proxies.GET("/all", h.Admin.Proxy.GetAll)
|
||||
proxies.GET("/:id", h.Admin.Proxy.GetByID)
|
||||
proxies.POST("", h.Admin.Proxy.Create)
|
||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
|
||||
}
|
||||
}
|
||||
|
||||
func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
codes := admin.Group("/redeem-codes")
|
||||
{
|
||||
codes.GET("", h.Admin.Redeem.List)
|
||||
codes.GET("/stats", h.Admin.Redeem.GetStats)
|
||||
codes.GET("/export", h.Admin.Redeem.Export)
|
||||
codes.GET("/:id", h.Admin.Redeem.GetByID)
|
||||
codes.POST("/generate", h.Admin.Redeem.Generate)
|
||||
codes.DELETE("/:id", h.Admin.Redeem.Delete)
|
||||
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
|
||||
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
adminSettings.GET("", h.Admin.Setting.GetSettings)
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
// Admin API Key 管理
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
|
||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
system.GET("/version", h.Admin.System.GetVersion)
|
||||
system.GET("/check-updates", h.Admin.System.CheckUpdates)
|
||||
system.POST("/update", h.Admin.System.PerformUpdate)
|
||||
system.POST("/rollback", h.Admin.System.Rollback)
|
||||
system.POST("/restart", h.Admin.System.RestartService)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
subscriptions := admin.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Admin.Subscription.List)
|
||||
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
|
||||
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
|
||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||
}
|
||||
|
||||
// 分组下的订阅列表
|
||||
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
|
||||
|
||||
// 用户下的订阅列表
|
||||
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
|
||||
}
|
||||
|
||||
func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
usage := admin.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Admin.Usage.List)
|
||||
usage.GET("/stats", h.Admin.Usage.Stats)
|
||||
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
|
||||
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
|
||||
}
|
||||
}
|
||||
36
backend/internal/server/routes/auth.go
Normal file
36
backend/internal/server/routes/auth.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterAuthRoutes 注册认证相关路由
|
||||
func RegisterAuthRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
) {
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
settings := v1.Group("/settings")
|
||||
{
|
||||
settings.GET("/public", h.Setting.GetPublicSettings)
|
||||
}
|
||||
|
||||
// 需要认证的当前用户信息
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
}
|
||||
}
|
||||
32
backend/internal/server/routes/common.go
Normal file
32
backend/internal/server/routes/common.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
|
||||
func RegisterCommonRoutes(r *gin.Engine) {
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Claude Code 遥测日志(忽略,直接返回200)
|
||||
r.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
// Setup status endpoint (always returns needs_setup: false in normal mode)
|
||||
// This is used by the frontend to detect when the service has restarted after setup
|
||||
r.GET("/setup/status", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"data": gin.H{
|
||||
"needs_setup": false,
|
||||
"step": "completed",
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
30
backend/internal/server/routes/gateway.go
Normal file
30
backend/internal/server/routes/gateway.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterGatewayRoutes 注册 API 网关路由(Claude/OpenAI 兼容)
|
||||
func RegisterGatewayRoutes(
|
||||
r *gin.Engine,
|
||||
h *handler.Handlers,
|
||||
apiKeyAuth middleware.ApiKeyAuthMiddleware,
|
||||
) {
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||
}
|
||||
72
backend/internal/server/routes/user.go
Normal file
72
backend/internal/server/routes/user.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterUserRoutes 注册用户相关路由(需要认证)
|
||||
func RegisterUserRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
) {
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
{
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
{
|
||||
user.GET("/profile", h.User.GetProfile)
|
||||
user.PUT("/password", h.User.ChangePassword)
|
||||
user.PUT("", h.User.UpdateProfile)
|
||||
}
|
||||
|
||||
// API Key管理
|
||||
keys := authenticated.Group("/keys")
|
||||
{
|
||||
keys.GET("", h.APIKey.List)
|
||||
keys.GET("/:id", h.APIKey.GetByID)
|
||||
keys.POST("", h.APIKey.Create)
|
||||
keys.PUT("/:id", h.APIKey.Update)
|
||||
keys.DELETE("/:id", h.APIKey.Delete)
|
||||
}
|
||||
|
||||
// 用户可用分组(非管理员接口)
|
||||
groups := authenticated.Group("/groups")
|
||||
{
|
||||
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
||||
}
|
||||
|
||||
// 使用记录
|
||||
usage := authenticated.Group("/usage")
|
||||
{
|
||||
usage.GET("", h.Usage.List)
|
||||
usage.GET("/:id", h.Usage.GetByID)
|
||||
usage.GET("/stats", h.Usage.Stats)
|
||||
// User dashboard endpoints
|
||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
redeem := authenticated.Group("/redeem")
|
||||
{
|
||||
redeem.POST("", h.Redeem.Redeem)
|
||||
redeem.GET("/history", h.Redeem.GetHistory)
|
||||
}
|
||||
|
||||
// 用户订阅
|
||||
subscriptions := authenticated.Group("/subscriptions")
|
||||
{
|
||||
subscriptions.GET("", h.Subscription.List)
|
||||
subscriptions.GET("/active", h.Subscription.GetActive)
|
||||
subscriptions.GET("/progress", h.Subscription.GetProgress)
|
||||
subscriptions.GET("/summary", h.Subscription.GetSummary)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,19 +2,63 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAccountNotFound = errors.New("account not found")
|
||||
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *model.Account) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Account, error)
|
||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||
// Returns (nil, nil) if not found.
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
|
||||
Update(ctx context.Context, account *model.Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListActive(ctx context.Context) ([]model.Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]model.Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
}
|
||||
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
// Nil pointers mean "do not change".
|
||||
type AccountBulkUpdate struct {
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status *string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
@@ -42,12 +86,12 @@ type UpdateAccountRequest struct {
|
||||
|
||||
// AccountService 账号管理服务
|
||||
type AccountService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
groupRepo ports.GroupRepository
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
|
||||
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -61,9 +105,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
for _, groupID := range req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("group %d not found", groupID)
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -100,9 +141,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAccountNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
return account, nil
|
||||
@@ -139,9 +177,6 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
|
||||
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAccountNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
@@ -184,9 +219,6 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
for _, groupID := range *req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("group %d not found", groupID)
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -204,9 +236,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
|
||||
// 检查账号是否存在
|
||||
_, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrAccountNotFound
|
||||
}
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
@@ -221,9 +250,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
|
||||
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrAccountNotFound
|
||||
}
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
@@ -249,9 +275,6 @@ func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", ErrAccountNotFound
|
||||
}
|
||||
return "", fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
@@ -262,9 +285,6 @@ func (s *AccountService) GetCredential(ctx context.Context, id int64, key string
|
||||
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrAccountNotFound
|
||||
}
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,8 +17,6 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -40,14 +38,14 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
accountRepo AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
httpUpstream ports.HTTPUpstream
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream ports.HTTPUpstream) *AccountTestService {
|
||||
func NewAccountTestService(accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream HTTPUpstream) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
|
||||
@@ -8,10 +8,49 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
Create(ctx context.Context, log *model.UsageLog) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
|
||||
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
||||
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
|
||||
|
||||
// User dashboard stats
|
||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
|
||||
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
|
||||
|
||||
// Admin usage listing/stats
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
|
||||
// Account stats
|
||||
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
|
||||
}
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
type usageCache struct {
|
||||
data *UsageInfo
|
||||
@@ -69,13 +108,13 @@ type ClaudeUsageFetcher interface {
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
type AccountUsageService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
|
||||
@@ -9,9 +9,6 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AdminService interface defines admin management operations
|
||||
@@ -221,24 +218,24 @@ type ProxyExitInfoProber interface {
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
accountRepo ports.AccountRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
redeemCodeRepo ports.RedeemCodeRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo ApiKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
accountRepo ports.AccountRepository,
|
||||
proxyRepo ports.ProxyRepository,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
redeemCodeRepo ports.RedeemCodeRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
) AdminService {
|
||||
@@ -552,61 +549,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
// 先获取分组信息,检查是否存在
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
|
||||
// 订阅类型分组:先获取受影响的用户ID列表(用于事务后失效缓存)
|
||||
var affectedUserIDs []int64
|
||||
if group.IsSubscriptionType() && s.billingCacheService != nil {
|
||||
var subscriptions []model.UserSubscription
|
||||
if err := s.groupRepo.DB().WithContext(ctx).
|
||||
Where("group_id = ?", id).
|
||||
Select("user_id").
|
||||
Find(&subscriptions).Error; err == nil {
|
||||
for _, sub := range subscriptions {
|
||||
affectedUserIDs = append(affectedUserIDs, sub.UserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用事务处理所有级联删除
|
||||
db := s.groupRepo.DB()
|
||||
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
|
||||
if group.IsSubscriptionType() {
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
|
||||
return fmt.Errorf("delete user subscriptions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil(任何类型的分组都需要)
|
||||
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
|
||||
return fmt.Errorf("clear api key group_id: %w", err)
|
||||
}
|
||||
|
||||
// 3. 从 users.allowed_groups 数组中移除该分组 ID
|
||||
if err := tx.Model(&model.User{}).
|
||||
Where("? = ANY(allowed_groups)", id).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
|
||||
return fmt.Errorf("remove from allowed_groups: %w", err)
|
||||
}
|
||||
|
||||
// 4. 删除 account_groups 中间表的数据
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
return fmt.Errorf("delete account groups: %w", err)
|
||||
}
|
||||
|
||||
// 5. 删除分组本身
|
||||
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
|
||||
return fmt.Errorf("delete group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -695,6 +638,7 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
account.ProxyID = input.ProxyID
|
||||
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
|
||||
}
|
||||
// 只在指针非 nil 时更新 Concurrency(支持设置为 0)
|
||||
if input.Concurrency != nil {
|
||||
@@ -719,7 +663,8 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
// 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象)
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// BulkUpdateAccounts updates multiple accounts in one request.
|
||||
@@ -734,7 +679,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
}
|
||||
|
||||
// Prepare bulk updates for columns and JSONB fields.
|
||||
repoUpdates := ports.AccountBulkUpdate{
|
||||
repoUpdates := AccountBulkUpdate{
|
||||
Credentials: input.Credentials,
|
||||
Extra: input.Extra,
|
||||
}
|
||||
|
||||
@@ -6,30 +6,55 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
|
||||
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *model.ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
|
||||
Update(ctx context.Context, key *model.ApiKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
|
||||
// ApiKeyCache defines cache operations for API key service
|
||||
type ApiKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
type CreateApiKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
@@ -46,21 +71,21 @@ type UpdateApiKeyRequest struct {
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.ApiKeyCache
|
||||
apiKeyRepo ApiKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache ApiKeyCache
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.ApiKeyCache,
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache ApiKeyCache,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
@@ -158,9 +183,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
@@ -168,9 +190,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
if req.GroupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
@@ -244,9 +263,6 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
@@ -260,9 +276,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
|
||||
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
|
||||
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
@@ -279,9 +292,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrApiKeyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
@@ -304,9 +314,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
@@ -336,9 +343,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrApiKeyNotFound
|
||||
}
|
||||
return fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
@@ -369,15 +373,12 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
|
||||
|
||||
// 检查API Key状态
|
||||
if !apiKey.IsActive() {
|
||||
return nil, nil, errors.New("api key is not active")
|
||||
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, ErrUserNotFound
|
||||
}
|
||||
return nil, nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
@@ -411,9 +412,6 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
@@ -425,7 +423,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
|
||||
|
||||
// 获取用户的所有有效订阅
|
||||
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,26 +4,26 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid email or password")
|
||||
ErrUserNotActive = errors.New("user is not active")
|
||||
ErrEmailExists = errors.New("email already exists")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
ErrEmailVerifyRequired = errors.New("email verification is required")
|
||||
ErrRegDisabled = errors.New("registration is currently disabled")
|
||||
ErrServiceUnavailable = errors.New("service temporarily unavailable")
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
)
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
@@ -36,7 +36,7 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo ports.UserRepository
|
||||
userRepo UserRepository
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
@@ -46,7 +46,7 @@ type AuthService struct {
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
userRepo ports.UserRepository,
|
||||
userRepo UserRepository,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
@@ -255,7 +255,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
// 查找用户
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return "", nil, ErrInvalidCredentials
|
||||
}
|
||||
// 记录数据库错误但不暴露给用户
|
||||
@@ -357,7 +357,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
// 获取最新的用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, claims.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
log.Printf("[Auth] Database error refreshing token: %v", err)
|
||||
|
||||
15
backend/internal/service/billing_cache_port.go
Normal file
15
backend/internal/service/billing_cache_port.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SubscriptionCacheData represents cached subscription data
|
||||
type SubscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
@@ -2,20 +2,19 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
// 注:ErrInsufficientBalance在redeem_service.go中定义
|
||||
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
|
||||
var (
|
||||
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
@@ -31,13 +30,13 @@ type subscriptionCacheData struct {
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
cache ports.BillingCache
|
||||
userRepo ports.UserRepository
|
||||
subRepo ports.UserSubscriptionRepository
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache ports.BillingCache, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
@@ -149,7 +148,7 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCacheData) *subscriptionCacheData {
|
||||
func (s *BillingCacheService) convertFromPortsData(data *SubscriptionCacheData) *subscriptionCacheData {
|
||||
return &subscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
@@ -160,8 +159,8 @@ func (s *BillingCacheService) convertFromPortsData(data *ports.SubscriptionCache
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *ports.SubscriptionCacheData {
|
||||
return &ports.SubscriptionCacheData{
|
||||
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *SubscriptionCacheData {
|
||||
return &SubscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
GetUserBalance(ctx context.Context, userID int64) (float64, error)
|
||||
SetUserBalance(ctx context.Context, userID int64, balance float64) error
|
||||
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
|
||||
InvalidateUserBalance(ctx context.Context, userID int64) error
|
||||
|
||||
// Subscription operations
|
||||
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
|
||||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||
}
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
|
||||
@@ -7,10 +7,28 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// ConcurrencyCache defines cache operations for concurrency service
|
||||
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
|
||||
type ConcurrencyCache interface {
|
||||
// Account slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:account:{accountID}:{requestID}
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// User slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:user:{userID}:{requestID}
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// Wait queue - uses counter with TTL set only on creation
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking
|
||||
// Uses 8 random bytes (16 hex chars) for uniqueness
|
||||
func generateRequestID() string {
|
||||
@@ -29,11 +47,11 @@ const (
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
type ConcurrencyService struct {
|
||||
cache ports.ConcurrencyCache
|
||||
cache ConcurrencyCache
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,19 +13,18 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
}
|
||||
|
||||
func NewCRSSyncService(
|
||||
accountRepo ports.AccountRepository,
|
||||
proxyRepo ports.ProxyRepository,
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
) *CRSSyncService {
|
||||
|
||||
@@ -6,15 +6,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// DashboardService provides aggregated statistics for admin dashboard.
|
||||
type DashboardService struct {
|
||||
usageRepo ports.UsageLogRepository
|
||||
usageRepo UsageLogRepository
|
||||
}
|
||||
|
||||
func NewDashboardService(usageRepo ports.UsageLogRepository) *DashboardService {
|
||||
func NewDashboardService(usageRepo UsageLogRepository) *DashboardService {
|
||||
return &DashboardService{
|
||||
usageRepo: usageRepo,
|
||||
}
|
||||
|
||||
@@ -4,23 +4,37 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmailNotConfigured = errors.New("email service not configured")
|
||||
ErrInvalidVerifyCode = errors.New("invalid or expired verification code")
|
||||
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code")
|
||||
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code")
|
||||
ErrEmailNotConfigured = infraerrors.ServiceUnavailable("EMAIL_NOT_CONFIGURED", "email service not configured")
|
||||
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
|
||||
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
|
||||
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
|
||||
)
|
||||
|
||||
// EmailCache defines cache operations for email service
|
||||
type EmailCache interface {
|
||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||
DeleteVerificationCode(ctx context.Context, email string) error
|
||||
}
|
||||
|
||||
// VerificationCodeData represents verification code data
|
||||
type VerificationCodeData struct {
|
||||
Code string
|
||||
Attempts int
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
@@ -40,12 +54,12 @@ type SmtpConfig struct {
|
||||
|
||||
// EmailService 邮件服务
|
||||
type EmailService struct {
|
||||
settingRepo ports.SettingRepository
|
||||
cache ports.EmailCache
|
||||
settingRepo SettingRepository
|
||||
cache EmailCache
|
||||
}
|
||||
|
||||
// NewEmailService 创建邮件服务实例
|
||||
func NewEmailService(settingRepo ports.SettingRepository, cache ports.EmailCache) *EmailService {
|
||||
func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailService {
|
||||
return &EmailService{
|
||||
settingRepo: settingRepo,
|
||||
cache: cache,
|
||||
@@ -205,7 +219,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
}
|
||||
|
||||
// 保存验证码到 Redis
|
||||
data := &ports.VerificationCodeData{
|
||||
data := &VerificationCodeData{
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now(),
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@@ -54,6 +53,13 @@ var allowedHeaders = map[string]bool{
|
||||
"content-type": true,
|
||||
}
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
type GatewayCache interface {
|
||||
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
|
||||
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// ClaudeUsage 表示Claude API返回的usage信息
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
@@ -74,32 +80,32 @@ type ForwardResult struct {
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.GatewayCache
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
httpUpstream ports.HTTPUpstream
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
func NewGatewayService(
|
||||
accountRepo ports.AccountRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.GatewayCache,
|
||||
accountRepo AccountRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
identityService *IdentityService,
|
||||
httpUpstream ports.HTTPUpstream,
|
||||
httpUpstream HTTPUpstream,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -362,8 +368,8 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
|
||||
|
||||
// 重试相关常量
|
||||
const (
|
||||
maxRetries = 5 // 最大重试次数
|
||||
retryDelay = 6 * time.Second // 重试等待时间
|
||||
maxRetries = 10 // 最大重试次数
|
||||
retryDelay = 3 * time.Second // 重试等待时间
|
||||
)
|
||||
|
||||
// shouldRetryUpstreamError 判断是否应该重试上游错误
|
||||
@@ -507,7 +513,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// OAuth账号:应用统一指纹
|
||||
var fingerprint *ports.Fingerprint
|
||||
var fingerprint *Fingerprint
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
|
||||
@@ -2,20 +2,35 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrGroupNotFound = errors.New("group not found")
|
||||
ErrGroupExists = errors.New("group name already exists")
|
||||
ErrGroupNotFound = infraerrors.NotFound("GROUP_NOT_FOUND", "group not found")
|
||||
ErrGroupExists = infraerrors.Conflict("GROUP_EXISTS", "group name already exists")
|
||||
)
|
||||
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *model.Group) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Group, error)
|
||||
Update(ctx context.Context, group *model.Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
|
||||
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
@@ -35,11 +50,11 @@ type UpdateGroupRequest struct {
|
||||
|
||||
// GroupService 分组管理服务
|
||||
type GroupService struct {
|
||||
groupRepo ports.GroupRepository
|
||||
groupRepo GroupRepository
|
||||
}
|
||||
|
||||
// NewGroupService 创建分组服务实例
|
||||
func NewGroupService(groupRepo ports.GroupRepository) *GroupService {
|
||||
func NewGroupService(groupRepo GroupRepository) *GroupService {
|
||||
return &GroupService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
@@ -76,9 +91,6 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
|
||||
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
return group, nil
|
||||
@@ -106,9 +118,6 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
@@ -153,9 +162,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
||||
// 检查分组是否存在
|
||||
_, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrGroupNotFound
|
||||
}
|
||||
return fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
@@ -170,9 +176,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
||||
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package ports
|
||||
package service
|
||||
|
||||
import "net/http"
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
@@ -24,7 +23,7 @@ var (
|
||||
)
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = ports.Fingerprint{
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
@@ -34,20 +33,38 @@ var defaultFingerprint = ports.Fingerprint{
|
||||
StainlessRuntimeVersion: "v22.14.0",
|
||||
}
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
type Fingerprint struct {
|
||||
ClientID string
|
||||
UserAgent string
|
||||
StainlessLang string
|
||||
StainlessPackageVersion string
|
||||
StainlessOS string
|
||||
StainlessArch string
|
||||
StainlessRuntime string
|
||||
StainlessRuntimeVersion string
|
||||
}
|
||||
|
||||
// IdentityCache defines cache operations for identity service
|
||||
type IdentityCache interface {
|
||||
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||
}
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
type IdentityService struct {
|
||||
cache ports.IdentityCache
|
||||
cache IdentityCache
|
||||
}
|
||||
|
||||
// NewIdentityService 创建新的IdentityService
|
||||
func NewIdentityService(cache ports.IdentityCache) *IdentityService {
|
||||
func NewIdentityService(cache IdentityCache) *IdentityService {
|
||||
return &IdentityService{cache: cache}
|
||||
}
|
||||
|
||||
// GetOrCreateFingerprint 获取或创建账号的指纹
|
||||
// 如果缓存存在,检测user-agent版本,新版本则更新
|
||||
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*ports.Fingerprint, error) {
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
|
||||
// 尝试从缓存获取指纹
|
||||
cached, err := s.cache.GetFingerprint(ctx, accountID)
|
||||
if err == nil && cached != nil {
|
||||
@@ -79,8 +96,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// createFingerprintFromHeaders 从请求头创建指纹
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *ports.Fingerprint {
|
||||
fp := &ports.Fingerprint{}
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
|
||||
fp := &Fingerprint{}
|
||||
|
||||
// 获取User-Agent
|
||||
if ua := headers.Get("User-Agent"); ua != "" {
|
||||
@@ -109,7 +126,7 @@ func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
|
||||
}
|
||||
|
||||
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerprint) {
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
||||
if fp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,9 +8,15 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// OpenAIOAuthClient interface for OpenAI OAuth operations
|
||||
type OpenAIOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||
}
|
||||
|
||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||
type ClaudeOAuthClient interface {
|
||||
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
@@ -22,12 +28,12 @@ type ClaudeOAuthClient interface {
|
||||
// OAuthService handles OAuth authentication flows
|
||||
type OAuthService struct {
|
||||
sessionStore *oauth.SessionStore
|
||||
proxyRepo ports.ProxyRepository
|
||||
proxyRepo ProxyRepository
|
||||
oauthClient ClaudeOAuthClient
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(proxyRepo ports.ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
|
||||
func NewOAuthService(proxyRepo ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
|
||||
return &OAuthService{
|
||||
sessionStore: oauth.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user