mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-11 10:34:46 +08:00
Merge pull request #966 from GuangYiDing/feat/db-backup-restore
feat: 数据库定时备份与恢复(S3 兼容存储,支持 Cloudflare R2)
This commit is contained in:
@@ -94,6 +94,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -230,6 +231,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -146,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -201,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -232,7 +236,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -285,6 +289,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -420,6 +425,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
204
backend/internal/handler/admin/backup_handler.go
Normal file
204
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupHandler struct {
|
||||
backupService *service.BackupService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||
return &BackupHandler{
|
||||
backupService: backupService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"restored": true})
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
backupHandler *admin.BackupHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
Backup: backupHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewBackupHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||
type PgDumper struct {
|
||||
cfg *config.DatabaseConfig
|
||||
}
|
||||
|
||||
// NewPgDumper creates a new PgDumper
|
||||
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||
return &PgDumper{cfg: &cfg.Database}
|
||||
}
|
||||
|
||||
// Dump executes pg_dump and returns a streaming reader of the output
|
||||
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||
}
|
||||
|
||||
// Restore executes psql to restore from a streaming reader
|
||||
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = data
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||
type cmdReadCloser struct {
|
||||
io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *cmdReadCloser) Close() error {
|
||||
// Close the pipe first
|
||||
_ = c.ReadCloser.Close()
|
||||
// Wait for the process to exit
|
||||
if err := c.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
backend/internal/repository/backup_s3_store.go
Normal file
116
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||
type S3BackupStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
|
||||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(data),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||
}
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||
}
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||
presignClient := s3.NewPresignClient(s.client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
}, s3.WithPresignExpires(expiry))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &s.bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -100,6 +100,10 @@ var ProviderSet = wire.NewSet(
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
// Backup infrastructure
|
||||
NewPgDumper,
|
||||
NewS3BackupStoreFactory,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
ProvidePricingRemoteClient,
|
||||
|
||||
@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
|
||||
// 数据管理
|
||||
registerDataManagementRoutes(admin, h)
|
||||
|
||||
// 数据库备份恢复
|
||||
registerBackupRoutes(admin, h)
|
||||
|
||||
// 运维监控(Ops)
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
@@ -440,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
backup := admin.Group("/backups")
|
||||
{
|
||||
// S3 存储配置
|
||||
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
|
||||
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
|
||||
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
|
||||
|
||||
// 定时备份配置
|
||||
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
|
||||
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
|
||||
|
||||
// 备份操作
|
||||
backup.POST("", h.Admin.Backup.CreateBackup)
|
||||
backup.GET("", h.Admin.Backup.ListBackups)
|
||||
backup.GET("/:id", h.Admin.Backup.GetBackup)
|
||||
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
|
||||
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
|
||||
|
||||
// 恢复操作
|
||||
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
|
||||
776
backend/internal/service/backup_service.go
Normal file
776
backend/internal/service/backup_service.go
Normal file
@@ -0,0 +1,776 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
settingKeyBackupS3Config = "backup_s3_config"
|
||||
settingKeyBackupSchedule = "backup_schedule"
|
||||
settingKeyBackupRecords = "backup_records"
|
||||
|
||||
maxBackupRecords = 100
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
|
||||
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
|
||||
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
|
||||
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
|
||||
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
|
||||
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
|
||||
)
|
||||
|
||||
// ─── 接口定义 ───
|
||||
|
||||
// DBDumper abstracts database dump/restore operations
|
||||
type DBDumper interface {
|
||||
Dump(ctx context.Context) (io.ReadCloser, error)
|
||||
Restore(ctx context.Context, data io.Reader) error
|
||||
}
|
||||
|
||||
// BackupObjectStore abstracts object storage for backup files
|
||||
type BackupObjectStore interface {
|
||||
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
|
||||
Download(ctx context.Context, key string) (io.ReadCloser, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
|
||||
HeadBucket(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackupObjectStoreFactory creates an object store from S3 config
|
||||
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
|
||||
|
||||
// ─── 数据模型 ───
|
||||
|
||||
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2)
|
||||
type BackupS3Config struct {
|
||||
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
|
||||
Region string `json:"region"` // R2 用 "auto"
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
|
||||
Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/"
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
}
|
||||
|
||||
// IsConfigured 检查必要字段是否已配置
|
||||
func (c *BackupS3Config) IsConfigured() bool {
|
||||
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
|
||||
}
|
||||
|
||||
// BackupScheduleConfig 定时备份配置
|
||||
type BackupScheduleConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点
|
||||
RetainDays int `json:"retain_days"` // 备份文件过期天数,默认14,0=不自动清理
|
||||
RetainCount int `json:"retain_count"` // 最多保留份数,0=不限制
|
||||
}
|
||||
|
||||
// BackupRecord 备份记录
|
||||
type BackupRecord struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
BackupType string `json:"backup_type"` // postgres
|
||||
FileName string `json:"file_name"`
|
||||
S3Key string `json:"s3_key"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
||||
ErrorMsg string `json:"error_message,omitempty"`
|
||||
StartedAt string `json:"started_at"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
||||
}
|
||||
|
||||
// BackupService 数据库备份恢复服务
|
||||
type BackupService struct {
|
||||
settingRepo SettingRepository
|
||||
dbCfg *config.DatabaseConfig
|
||||
encryptor SecretEncryptor
|
||||
storeFactory BackupObjectStoreFactory
|
||||
dumper DBDumper
|
||||
|
||||
mu sync.Mutex
|
||||
store BackupObjectStore
|
||||
s3Cfg *BackupS3Config
|
||||
backingUp bool
|
||||
restoring bool
|
||||
|
||||
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||
|
||||
cronMu sync.Mutex
|
||||
cronSched *cron.Cron
|
||||
cronEntryID cron.EntryID
|
||||
}
|
||||
|
||||
func NewBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
return &BackupService{
|
||||
settingRepo: settingRepo,
|
||||
dbCfg: &cfg.Database,
|
||||
encryptor: encryptor,
|
||||
storeFactory: storeFactory,
|
||||
dumper: dumper,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动定时备份调度器
|
||||
func (s *BackupService) Start() {
|
||||
s.cronSched = cron.New()
|
||||
s.cronSched.Start()
|
||||
|
||||
// 加载已有的定时配置
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
schedule, err := s.GetSchedule(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
|
||||
return
|
||||
}
|
||||
if schedule.Enabled && schedule.CronExpr != "" {
|
||||
if err := s.applyCronSchedule(schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止定时备份
|
||||
func (s *BackupService) Stop() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil {
|
||||
s.cronSched.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置管理 ───
|
||||
|
||||
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return &BackupS3Config{}, nil
|
||||
}
|
||||
// 脱敏返回
|
||||
cfg.SecretAccessKey = ""
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
|
||||
// 如果没提供 secret,保留原有值
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
} else {
|
||||
// 加密 SecretAccessKey
|
||||
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt secret: %w", err)
|
||||
}
|
||||
cfg.SecretAccessKey = encrypted
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal s3 config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save s3 config: %w", err)
|
||||
}
|
||||
|
||||
// 清除缓存的 S3 客户端
|
||||
s.mu.Lock()
|
||||
s.store = nil
|
||||
s.s3Cfg = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
cfg.SecretAccessKey = ""
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
|
||||
// 如果没提供 secret,用已保存的
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, &cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return store.HeadBucket(ctx)
|
||||
}
|
||||
|
||||
// ─── 定时备份管理 ───
|
||||
|
||||
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
|
||||
if err != nil || raw == "" {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
var cfg BackupScheduleConfig
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
|
||||
if cfg.Enabled && cfg.CronExpr == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
|
||||
}
|
||||
// 验证 cron 表达式
|
||||
if cfg.CronExpr != "" {
|
||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
if _, err := parser.Parse(cfg.CronExpr); err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal schedule config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save schedule config: %w", err)
|
||||
}
|
||||
|
||||
// 应用或停止定时任务
|
||||
if cfg.Enabled {
|
||||
if err := s.applyCronSchedule(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
s.removeCronSchedule()
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
|
||||
if s.cronSched == nil {
|
||||
return fmt.Errorf("cron scheduler not initialized")
|
||||
}
|
||||
|
||||
// 移除旧任务
|
||||
if s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
}
|
||||
|
||||
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
|
||||
s.runScheduledBackup()
|
||||
})
|
||||
if err != nil {
|
||||
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
|
||||
}
|
||||
s.cronEntryID = entryID
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) removeCronSchedule() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil && s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) runScheduledBackup() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 读取定时备份配置中的过期天数
|
||||
schedule, _ := s.GetSchedule(ctx)
|
||||
expireDays := 14 // 默认14天过期
|
||||
if schedule != nil && schedule.RetainDays > 0 {
|
||||
expireDays = schedule.RetainDays
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
||||
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
||||
|
||||
// 清理过期备份(复用已加载的 schedule)
|
||||
if schedule == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── 备份/恢复核心 ───
|
||||
|
||||
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||
s.mu.Lock()
|
||||
if s.backingUp {
|
||||
s.mu.Unlock()
|
||||
return nil, ErrBackupInProgress
|
||||
}
|
||||
s.backingUp = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.backingUp = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s3Cfg == nil || !s3Cfg.IsConfigured() {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
backupID := uuid.New().String()[:8]
|
||||
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
|
||||
s3Key := s.buildS3Key(s3Cfg, fileName)
|
||||
|
||||
var expiresAt string
|
||||
if expireDays > 0 {
|
||||
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
record := &BackupRecord{
|
||||
ID: backupID,
|
||||
Status: "running",
|
||||
BackupType: "postgres",
|
||||
FileName: fileName,
|
||||
S3Key: s3Key,
|
||||
TriggeredBy: triggeredBy,
|
||||
StartedAt: now.Format(time.RFC3339),
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
// 流式执行: pg_dump -> gzip -> S3 upload
|
||||
dumpReader, err := s.dumper.Dump(ctx)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||
pr, pw := io.Pipe()
|
||||
var gzipErr error
|
||||
go func() {
|
||||
gzWriter := gzip.NewWriter(pw)
|
||||
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
||||
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if gzipErr != nil {
|
||||
_ = pw.CloseWithError(gzipErr)
|
||||
} else {
|
||||
_ = pw.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
contentType := "application/gzip"
|
||||
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||
if gzipErr != nil {
|
||||
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
||||
}
|
||||
record.ErrorMsg = errMsg
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("backup upload: %w", err)
|
||||
}
|
||||
|
||||
record.SizeBytes = sizeBytes
|
||||
record.Status = "completed"
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
if err := s.saveRecord(ctx, record); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||
s.mu.Lock()
|
||||
if s.restoring {
|
||||
s.mu.Unlock()
|
||||
return ErrRestoreInProgress
|
||||
}
|
||||
s.restoring = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.restoring = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
// 从 S3 流式下载
|
||||
body, err := objectStore.Download(ctx, record.S3Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 download failed: %w", err)
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
// 流式解压 gzip -> psql(不将全部数据加载到内存)
|
||||
gzReader, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer func() { _ = gzReader.Close() }()
|
||||
|
||||
// 流式恢复
|
||||
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||
return fmt.Errorf("pg restore: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ─── 备份记录管理 ───
|
||||
|
||||
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 倒序返回(最新在前)
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
return &records[i], nil
|
||||
}
|
||||
}
|
||||
return nil, ErrBackupNotFound
|
||||
}
|
||||
|
||||
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var found *BackupRecord
|
||||
var remaining []BackupRecord
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
found = &records[i]
|
||||
} else {
|
||||
remaining = append(remaining, records[i])
|
||||
}
|
||||
}
|
||||
if found == nil {
|
||||
return ErrBackupNotFound
|
||||
}
|
||||
|
||||
// 从 S3 删除
|
||||
if found.S3Key != "" && found.Status == "completed" {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err == nil {
|
||||
_ = objectStore.Delete(ctx, found.S3Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, remaining)
|
||||
}
|
||||
|
||||
// GetBackupDownloadURL 获取备份文件预签名下载 URL
|
||||
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// ─── 内部方法 ───
|
||||
|
||||
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no config is a valid state
|
||||
}
|
||||
var cfg BackupS3Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return nil, ErrBackupS3ConfigCorrupt
|
||||
}
|
||||
// 解密 SecretAccessKey
|
||||
if cfg.SecretAccessKey != "" {
|
||||
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
// 兼容未加密的旧数据:如果解密失败,保持原值
|
||||
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
|
||||
} else {
|
||||
cfg.SecretAccessKey = decrypted
|
||||
}
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.store != nil && s.s3Cfg != nil {
|
||||
return s.store, nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.store = store
|
||||
s.s3Cfg = cfg
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
|
||||
prefix := strings.TrimRight(cfg.Prefix, "/")
|
||||
if prefix == "" {
|
||||
prefix = "backups"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
|
||||
}
|
||||
|
||||
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
|
||||
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
return s.loadRecordsLocked(ctx)
|
||||
}
|
||||
|
||||
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
|
||||
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no records is a valid state
|
||||
}
|
||||
var records []BackupRecord
|
||||
if err := json.Unmarshal([]byte(raw), &records); err != nil {
|
||||
return nil, ErrBackupRecordsCorrupt
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) saveRecords(ctx context.Context, records []BackupRecord) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
return s.saveRecordsLocked(ctx, records)
|
||||
}
|
||||
|
||||
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
|
||||
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
|
||||
data, err := json.Marshal(records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
|
||||
}
|
||||
|
||||
// saveRecord 保存单条记录(带互斥锁保护)
|
||||
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, _ := s.loadRecordsLocked(ctx)
|
||||
|
||||
// 更新已有记录或追加
|
||||
found := false
|
||||
for i := range records {
|
||||
if records[i].ID == record.ID {
|
||||
records[i] = *record
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
records = append(records, *record)
|
||||
}
|
||||
|
||||
// 限制记录数量
|
||||
if len(records) > maxBackupRecords {
|
||||
records = records[len(records)-maxBackupRecords:]
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, records)
|
||||
}
|
||||
|
||||
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
|
||||
if schedule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 按时间倒序
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
|
||||
var toDelete []BackupRecord
|
||||
var toKeep []BackupRecord
|
||||
|
||||
for i, r := range records {
|
||||
shouldDelete := false
|
||||
|
||||
// 按保留份数清理
|
||||
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
|
||||
shouldDelete = true
|
||||
}
|
||||
|
||||
// 按保留天数清理
|
||||
if schedule.RetainDays > 0 && r.StartedAt != "" {
|
||||
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
|
||||
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
|
||||
shouldDelete = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDelete && r.Status == "completed" {
|
||||
toDelete = append(toDelete, r)
|
||||
} else {
|
||||
toKeep = append(toKeep, r)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 S3 上的文件
|
||||
for _, r := range toDelete {
|
||||
if r.S3Key != "" {
|
||||
_ = s.deleteS3Object(ctx, r.S3Key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toDelete) > 0 {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
|
||||
return s.saveRecordsLocked(ctx, toKeep)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil || s3Cfg == nil {
|
||||
return nil
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return objectStore.Delete(ctx, key)
|
||||
}
|
||||
528
backend/internal/service/backup_service_test.go
Normal file
528
backend/internal/service/backup_service_test.go
Normal file
@@ -0,0 +1,528 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── Mocks ───
|
||||
|
||||
type mockSettingRepo struct {
|
||||
mu sync.Mutex
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func newMockSettingRepo() *mockSettingRepo {
|
||||
return &mockSettingRepo{data: make(map[string]string)}
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
return &Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
if v, ok := m.data[k]; ok {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for k, v := range settings {
|
||||
m.data[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string, len(m.data))
|
||||
for k, v := range m.data {
|
||||
result[k] = v
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// plainEncryptor 仅做 base64-like 包装,用于测试
|
||||
type plainEncryptor struct{}
|
||||
|
||||
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
return "ENC:" + plaintext, nil
|
||||
}
|
||||
|
||||
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
if strings.HasPrefix(ciphertext, "ENC:") {
|
||||
return strings.TrimPrefix(ciphertext, "ENC:"), nil
|
||||
}
|
||||
return ciphertext, fmt.Errorf("not encrypted")
|
||||
}
|
||||
|
||||
type mockDumper struct {
|
||||
dumpData []byte
|
||||
dumpErr error
|
||||
restored []byte
|
||||
restErr error
|
||||
}
|
||||
|
||||
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
|
||||
if m.dumpErr != nil {
|
||||
return nil, m.dumpErr
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
|
||||
}
|
||||
|
||||
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
|
||||
if m.restErr != nil {
|
||||
return m.restErr
|
||||
}
|
||||
d, err := io.ReadAll(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.restored = d
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockObjectStore struct {
|
||||
objects map[string][]byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockObjectStore() *mockObjectStore {
|
||||
return &mockObjectStore{objects: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[key] = data
|
||||
m.mu.Unlock()
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
|
||||
m.mu.Lock()
|
||||
data, ok := m.objects[key]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not found: %s", key)
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
delete(m.objects, key)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
|
||||
return "https://presigned.example.com/" + key, nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "test",
|
||||
DBName: "testdb",
|
||||
},
|
||||
}
|
||||
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
|
||||
return store, nil
|
||||
}
|
||||
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
|
||||
}
|
||||
|
||||
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
|
||||
t.Helper()
|
||||
cfg := BackupS3Config{
|
||||
Bucket: "test-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "ENC:secret123",
|
||||
Prefix: "backups",
|
||||
}
|
||||
data, _ := json.Marshal(cfg)
|
||||
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
|
||||
}
|
||||
|
||||
// ─── Tests ───
|
||||
|
||||
func TestBackupService_S3ConfigEncryption(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 保存配置 -> SecretAccessKey 应被加密
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "my-secret",
|
||||
Prefix: "backups",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 直接读取数据库中存储的值,应该是加密后的
|
||||
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
|
||||
var stored BackupS3Config
|
||||
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
|
||||
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
|
||||
|
||||
// 通过 GetS3Config 获取应该脱敏
|
||||
cfg, err := svc.GetS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cfg.SecretAccessKey)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
|
||||
// loadS3Config 内部应解密
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "my-secret", internal.SecretAccessKey)
|
||||
}
|
||||
|
||||
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 先保存一个有 secret 的配置
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "original-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 再更新时不提供 secret,应保留原值
|
||||
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID-NEW",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original-secret", internal.SecretAccessKey)
|
||||
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
|
||||
}
|
||||
|
||||
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
n := 20
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
record := &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", idx),
|
||||
Status: "completed",
|
||||
StartedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
_ = svc.saveRecord(context.Background(), record)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, n)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Empty(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, records) // 无数据时返回 nil
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.Error(t, err) // 损坏数据应返回错误
|
||||
require.Nil(t, records)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", record.Status)
|
||||
require.Greater(t, record.SizeBytes, int64(0))
|
||||
require.NotEmpty(t, record.S3Key)
|
||||
|
||||
// 验证 S3 上确实有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "failed", record.Status)
|
||||
require.Contains(t, record.ErrorMsg, "pg_dump")
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
// 使用一个慢速 dumper 来模拟正在进行的备份
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 手动设置 backingUp 标志
|
||||
svc.mu.Lock()
|
||||
svc.backingUp = true
|
||||
svc.mu.Unlock()
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 先创建一个备份
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 恢复
|
||||
err = svc.RestoreBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证 psql 收到的数据是否与原始 dump 内容一致
|
||||
require.Equal(t, dumpContent, string(dumper.restored))
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 手动插入一条 failed 记录
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: "fail-1",
|
||||
Status: "failed",
|
||||
})
|
||||
|
||||
err := svc.RestoreBackup(context.Background(), "fail-1")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_DeleteBackup(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "data"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中应有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 删除
|
||||
err = svc.DeleteBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中文件应被删除
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 0)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 记录应不存在
|
||||
_, err = svc.GetBackupRecord(context.Background(), record.ID)
|
||||
require.ErrorIs(t, err, ErrBackupNotFound)
|
||||
}
|
||||
|
||||
func TestBackupService_GetDownloadURL(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, url, "https://presigned.example.com/")
|
||||
}
|
||||
|
||||
func TestBackupService_ListBackups_Sorted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", i),
|
||||
Status: "completed",
|
||||
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
records, err := svc.ListBackups(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, 3)
|
||||
// 最新在前
|
||||
require.Equal(t, "rec-2", records[0].ID)
|
||||
require.Equal(t, "rec-0", records[2].ID)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, store)
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
AccessKeyID: "ak",
|
||||
SecretAccessKey: "sk",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "incomplete")
|
||||
}
|
||||
|
||||
func TestBackupService_Schedule_CronValidation(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
svc.cronSched = nil // 未初始化 cron
|
||||
|
||||
// 启用但 cron 为空
|
||||
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "",
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// 无效的 cron 表达式
|
||||
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "invalid",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
cfg, err := svc.loadS3Config(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Nil(t, cfg)
|
||||
}
|
||||
@@ -322,6 +322,19 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC
|
||||
return apiKeyService
|
||||
}
|
||||
|
||||
// ProvideBackupService creates and starts BackupService
|
||||
func ProvideBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
svc := NewBackupService(settingRepo, cfg, encryptor, storeFactory, dumper)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideSettingService wires SettingService with group reader for default subscription validation.
|
||||
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
|
||||
svc := NewSettingService(settingRepo, cfg)
|
||||
@@ -373,6 +386,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountTestService,
|
||||
ProvideSettingService,
|
||||
NewDataManagementService,
|
||||
ProvideBackupService,
|
||||
ProvideOpsSystemLogSink,
|
||||
NewOpsService,
|
||||
ProvideOpsMetricsCollector,
|
||||
|
||||
Reference in New Issue
Block a user