diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 139d883a..d7ce1340 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -145,8 +145,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) - backupService := service.ProvideBackupService(settingRepository, configConfig) - backupHandler := admin.NewBackupHandler(backupService) + backupObjectStoreFactory := repository.NewS3BackupStoreFactory() + dbDumper := repository.NewPgDumper(configConfig) + backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper) + backupHandler := admin.NewBackupHandler(backupService, userService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) diff --git a/backend/internal/handler/admin/backup_handler.go b/backend/internal/handler/admin/backup_handler.go index 818928c6..d19713ee 100644 --- a/backend/internal/handler/admin/backup_handler.go +++ b/backend/internal/handler/admin/backup_handler.go @@ -2,16 +2,21 @@ package admin import ( "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) type BackupHandler struct { backupService *service.BackupService + userService *service.UserService } -func NewBackupHandler(backupService *service.BackupService) *BackupHandler { - return &BackupHandler{backupService: backupService} +func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler { + return &BackupHandler{ + backupService: backupService, + userService: userService, + } } // ─── S3 配置 ─── @@ -154,7 +159,11 @@ func (h *BackupHandler) GetDownloadURL(c *gin.Context) { response.Success(c, gin.H{"url": url}) } -// ─── 恢复操作 ─── +// ─── 恢复操作(需要重新输入管理员密码) ─── + +type RestoreBackupRequest struct { + Password string `json:"password" binding:"required"` +} func (h *BackupHandler) RestoreBackup(c *gin.Context) { backupID := c.Param("id") @@ -162,6 +171,31 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) { response.BadRequest(c, "backup ID is required") return } + + var req RestoreBackupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "password is required for restore operation") + return + } + + // 从上下文获取当前管理员用户 ID + sub, ok := middleware.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "unauthorized") + return + } + + // 获取管理员用户并验证密码 + user, err := h.userService.GetByID(c.Request.Context(), sub.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if !user.CheckPassword(req.Password) { + response.BadRequest(c, "incorrect admin password") + return + } + if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/repository/backup_pg_dumper.go b/backend/internal/repository/backup_pg_dumper.go new file mode 100644 index 00000000..e9a92ef2 --- /dev/null +++ b/backend/internal/repository/backup_pg_dumper.go @@ -0,0 +1,98 @@ +package repository + +import ( + "context" + "fmt" + "io" + "os/exec" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// PgDumper implements service.DBDumper using pg_dump/psql +type PgDumper struct { + cfg *config.DatabaseConfig +} + +// NewPgDumper creates a new PgDumper +func NewPgDumper(cfg *config.Config) service.DBDumper { + return &PgDumper{cfg: &cfg.Database} +} + +// Dump executes pg_dump and returns a streaming reader of the output +func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) { + args := []string{ + "-h", d.cfg.Host, + "-p", fmt.Sprintf("%d", d.cfg.Port), + "-U", d.cfg.User, + "-d", d.cfg.DBName, + "--no-owner", + "--no-acl", + "--clean", + "--if-exists", + } + + cmd := exec.CommandContext(ctx, "pg_dump", args...) + if d.cfg.Password != "" { + cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password) + } + if d.cfg.SSLMode != "" { + cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("create stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start pg_dump: %w", err) + } + + // 返回一个 ReadCloser:读 stdout,关闭时等待进程退出 + return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil +} + +// Restore executes psql to restore from a streaming reader +func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error { + args := []string{ + "-h", d.cfg.Host, + "-p", fmt.Sprintf("%d", d.cfg.Port), + "-U", d.cfg.User, + "-d", d.cfg.DBName, + "--single-transaction", + } + + cmd := exec.CommandContext(ctx, "psql", args...) + if d.cfg.Password != "" { + cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password) + } + if d.cfg.SSLMode != "" { + cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode) + } + + cmd.Stdin = data + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("%v: %s", err, string(output)) + } + return nil +} + +// cmdReadCloser wraps a command stdout pipe and waits for the process on Close +type cmdReadCloser struct { + io.ReadCloser + cmd *exec.Cmd +} + +func (c *cmdReadCloser) Close() error { + // Close the pipe first + _ = c.ReadCloser.Close() + // Wait for the process to exit + if err := c.cmd.Wait(); err != nil { + return fmt.Errorf("pg_dump exited with error: %w", err) + } + return nil +} diff --git a/backend/internal/repository/backup_s3_store.go b/backend/internal/repository/backup_s3_store.go new file mode 100644 index 00000000..ba5434f5 --- /dev/null +++ b/backend/internal/repository/backup_s3_store.go @@ -0,0 +1,116 @@ +package repository + +import ( + "bytes" + "context" + "fmt" + "io" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage +type S3BackupStore struct { + client *s3.Client + bucket string +} + +// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores +func NewS3BackupStoreFactory() service.BackupObjectStoreFactory { + return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) { + region := cfg.Region + if region == "" { + region = "auto" // Cloudflare R2 默认 region + } + + awsCfg, err := awsconfig.LoadDefaultConfig(ctx, + awsconfig.WithRegion(region), + awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""), + ), + ) + if err != nil { + return nil, fmt.Errorf("load aws config: %w", err) + } + + client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { + if cfg.Endpoint != "" { + o.BaseEndpoint = &cfg.Endpoint + } + if cfg.ForcePathStyle { + o.UsePathStyle = true + } + o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware) + o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired + }) + + return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil + } +} + +func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) { + // 读取全部内容以获取大小(S3 PutObject 需要知道内容长度) + data, err := io.ReadAll(body) + if err != nil { + return 0, fmt.Errorf("read body: %w", err) + } + + _, err = s.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &s.bucket, + Key: &key, + Body: bytes.NewReader(data), + ContentType: &contentType, + }) + if err != nil { + return 0, fmt.Errorf("S3 PutObject: %w", err) + } + return int64(len(data)), nil +} + +func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) { + result, err := s.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: &s.bucket, + Key: &key, + }) + if err != nil { + return nil, fmt.Errorf("S3 GetObject: %w", err) + } + return result.Body, nil +} + +func (s *S3BackupStore) Delete(ctx context.Context, key string) error { + _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: &s.bucket, + Key: &key, + }) + return err +} + +func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) { + presignClient := s3.NewPresignClient(s.client) + result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ + Bucket: &s.bucket, + Key: &key, + }, s3.WithPresignExpires(expiry)) + if err != nil { + return "", fmt.Errorf("presign url: %w", err) + } + return result.URL, nil +} + +func (s *S3BackupStore) HeadBucket(ctx context.Context) error { + _, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: &s.bucket, + }) + if err != nil { + return fmt.Errorf("S3 HeadBucket failed: %w", err) + } + return nil +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 5fe7a98e..1d0b5600 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -99,6 +99,10 @@ var ProviderSet = wire.NewSet( // Encryptors NewAESEncryptor, + // Backup infrastructure + NewPgDumper, + NewS3BackupStoreFactory, + // HTTP service ports (DI Strategy A: return interface directly) NewTurnstileVerifier, ProvidePricingRemoteClient, diff --git a/backend/internal/service/backup_service.go b/backend/internal/service/backup_service.go index 53ae888b..d6cfaef6 100644 --- a/backend/internal/service/backup_service.go +++ b/backend/internal/service/backup_service.go @@ -1,23 +1,16 @@ package service import ( - "bytes" "compress/gzip" "context" "encoding/json" "fmt" "io" - "os/exec" "sort" "strings" "sync" "time" - "github.com/aws/aws-sdk-go-v2/aws" - v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" - awsconfig "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/google/uuid" "github.com/robfig/cron/v3" @@ -39,8 +32,32 @@ var ( ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found") ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress") ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress") + ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted") + ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted") ) +// ─── 接口定义 ─── + +// DBDumper abstracts database dump/restore operations +type DBDumper interface { + Dump(ctx context.Context) (io.ReadCloser, error) + Restore(ctx context.Context, data io.Reader) error +} + +// BackupObjectStore abstracts object storage for backup files +type BackupObjectStore interface { + Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error) + Download(ctx context.Context, key string) (io.ReadCloser, error) + Delete(ctx context.Context, key string) error + PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) + HeadBucket(ctx context.Context) error +} + +// BackupObjectStoreFactory creates an object store from S3 config +type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) + +// ─── 数据模型 ─── + // BackupS3Config S3 兼容存储配置(支持 Cloudflare R2) type BackupS3Config struct { Endpoint string `json:"endpoint"` // e.g. https://.r2.cloudflarestorage.com @@ -82,26 +99,39 @@ type BackupRecord struct { // BackupService 数据库备份恢复服务 type BackupService struct { - settingRepo SettingRepository - dbCfg *config.DatabaseConfig + settingRepo SettingRepository + dbCfg *config.DatabaseConfig + encryptor SecretEncryptor + storeFactory BackupObjectStoreFactory + dumper DBDumper mu sync.Mutex - s3Client *s3.Client + store BackupObjectStore s3Cfg *BackupS3Config backingUp bool restoring bool + recordsMu sync.Mutex // 保护 records 的 load/save 操作 + cronMu sync.Mutex cronSched *cron.Cron cronEntryID cron.EntryID } -func NewBackupService(settingRepo SettingRepository, cfg *config.Config) *BackupService { - svc := &BackupService{ - settingRepo: settingRepo, - dbCfg: &cfg.Database, +func NewBackupService( + settingRepo SettingRepository, + cfg *config.Config, + encryptor SecretEncryptor, + storeFactory BackupObjectStoreFactory, + dumper DBDumper, +) *BackupService { + return &BackupService{ + settingRepo: settingRepo, + dbCfg: &cfg.Database, + encryptor: encryptor, + storeFactory: storeFactory, + dumper: dumper, } - return svc } // Start 启动定时备份调度器 @@ -136,17 +166,16 @@ func (s *BackupService) Stop() { // ─── S3 配置管理 ─── func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) { - raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config) - if err != nil || raw == "" { - return &BackupS3Config{}, nil + cfg, err := s.loadS3Config(ctx) + if err != nil { + return nil, err } - var cfg BackupS3Config - if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + if cfg == nil { return &BackupS3Config{}, nil } // 脱敏返回 cfg.SecretAccessKey = "" - return &cfg, nil + return cfg, nil } func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) { @@ -156,6 +185,13 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) if old != nil { cfg.SecretAccessKey = old.SecretAccessKey } + } else { + // 加密 SecretAccessKey + encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey) + if err != nil { + return nil, fmt.Errorf("encrypt secret: %w", err) + } + cfg.SecretAccessKey = encrypted } data, err := json.Marshal(cfg) @@ -168,7 +204,7 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) // 清除缓存的 S3 客户端 s.mu.Lock() - s.s3Client = nil + s.store = nil s.s3Cfg = nil s.mu.Unlock() @@ -189,17 +225,11 @@ func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required") } - client, err := s.buildS3Client(ctx, &cfg) + store, err := s.storeFactory(ctx, &cfg) if err != nil { return err } - _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{ - Bucket: &cfg.Bucket, - }) - if err != nil { - return fmt.Errorf("S3 HeadBucket failed: %w", err) - } - return nil + return store.HeadBucket(ctx) } // ─── 定时备份管理 ─── @@ -313,7 +343,7 @@ func (s *BackupService) runScheduledBackup() { // ─── 备份/恢复核心 ─── -// CreateBackup 创建全量数据库备份并上传到 S3 +// CreateBackup 创建全量数据库备份并上传到 S3(流式处理) // expireDays: 备份过期天数,0=永不过期,默认14天 func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) { s.mu.Lock() @@ -337,9 +367,9 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex return nil, ErrBackupS3NotConfigured } - client, err := s.getOrCreateS3Client(ctx, s3Cfg) + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) if err != nil { - return nil, fmt.Errorf("init S3 client: %w", err) + return nil, fmt.Errorf("init object store: %w", err) } now := time.Now() @@ -363,8 +393,8 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex ExpiresAt: expiresAt, } - // 执行全量 pg_dump - dumpData, err := s.pgDump(ctx) + // 流式执行: pg_dump -> gzip -> S3 upload + dumpReader, err := s.dumper.Dump(ctx) if err != nil { record.Status = "failed" record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err) @@ -373,38 +403,40 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex return record, fmt.Errorf("pg_dump: %w", err) } - // gzip 压缩 - var compressed bytes.Buffer - gzWriter := gzip.NewWriter(&compressed) - if _, err := gzWriter.Write(dumpData); err != nil { - record.Status = "failed" - record.ErrorMsg = fmt.Sprintf("gzip failed: %v", err) - record.FinishedAt = time.Now().Format(time.RFC3339) - _ = s.saveRecord(ctx, record) - return record, fmt.Errorf("gzip: %w", err) - } - if err := gzWriter.Close(); err != nil { - return nil, fmt.Errorf("gzip close: %w", err) - } + // 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传 + pr, pw := io.Pipe() + var gzipErr error + go func() { + gzWriter := gzip.NewWriter(pw) + _, gzipErr = io.Copy(gzWriter, dumpReader) + if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil { + gzipErr = closeErr + } + if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil { + gzipErr = closeErr + } + if gzipErr != nil { + _ = pw.CloseWithError(gzipErr) + } else { + _ = pw.Close() + } + }() - record.SizeBytes = int64(compressed.Len()) - - // 上传到 S3 contentType := "application/gzip" - _, err = client.PutObject(ctx, &s3.PutObjectInput{ - Bucket: &s3Cfg.Bucket, - Key: &s3Key, - Body: bytes.NewReader(compressed.Bytes()), - ContentType: &contentType, - }) + sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType) if err != nil { record.Status = "failed" - record.ErrorMsg = fmt.Sprintf("S3 upload failed: %v", err) + errMsg := fmt.Sprintf("S3 upload failed: %v", err) + if gzipErr != nil { + errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr) + } + record.ErrorMsg = errMsg record.FinishedAt = time.Now().Format(time.RFC3339) _ = s.saveRecord(ctx, record) - return record, fmt.Errorf("s3 upload: %w", err) + return record, fmt.Errorf("backup upload: %w", err) } + record.SizeBytes = sizeBytes record.Status = "completed" record.FinishedAt = time.Now().Format(time.RFC3339) if err := s.saveRecord(ctx, record); err != nil { @@ -414,7 +446,7 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex return record, nil } -// RestoreBackup 从 S3 下载备份并恢复到数据库 +// RestoreBackup 从 S3 下载备份并流式恢复到数据库 func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error { s.mu.Lock() if s.restoring { @@ -441,35 +473,27 @@ func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) erro if err != nil { return err } - client, err := s.getOrCreateS3Client(ctx, s3Cfg) + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) if err != nil { - return fmt.Errorf("init S3 client: %w", err) + return fmt.Errorf("init object store: %w", err) } - // 从 S3 下载 - result, err := client.GetObject(ctx, &s3.GetObjectInput{ - Bucket: &s3Cfg.Bucket, - Key: &record.S3Key, - }) + // 从 S3 流式下载 + body, err := objectStore.Download(ctx, record.S3Key) if err != nil { return fmt.Errorf("S3 download failed: %w", err) } - defer func() { _ = result.Body.Close() }() + defer func() { _ = body.Close() }() - // 解压 gzip - gzReader, err := gzip.NewReader(result.Body) + // 流式解压 gzip -> psql(不将全部数据加载到内存) + gzReader, err := gzip.NewReader(body) if err != nil { return fmt.Errorf("gzip reader: %w", err) } defer func() { _ = gzReader.Close() }() - sqlData, err := io.ReadAll(gzReader) - if err != nil { - return fmt.Errorf("read backup data: %w", err) - } - - // 执行 psql 恢复 - if err := s.pgRestore(ctx, sqlData); err != nil { + // 流式恢复 + if err := s.dumper.Restore(ctx, gzReader); err != nil { return fmt.Errorf("pg restore: %w", err) } @@ -504,7 +528,10 @@ func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (* } func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error { - records, err := s.loadRecords(ctx) + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + + records, err := s.loadRecordsLocked(ctx) if err != nil { return err } @@ -526,17 +553,14 @@ func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error if found.S3Key != "" && found.Status == "completed" { s3Cfg, err := s.loadS3Config(ctx) if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() { - client, err := s.getOrCreateS3Client(ctx, s3Cfg) + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) if err == nil { - _, _ = client.DeleteObject(ctx, &s3.DeleteObjectInput{ - Bucket: &s3Cfg.Bucket, - Key: &found.S3Key, - }) + _ = objectStore.Delete(ctx, found.S3Key) } } } - return s.saveRecords(ctx, remaining) + return s.saveRecordsLocked(ctx, remaining) } // GetBackupDownloadURL 获取备份文件预签名下载 URL @@ -553,20 +577,16 @@ func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID strin if err != nil { return "", err } - client, err := s.getOrCreateS3Client(ctx, s3Cfg) + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) if err != nil { return "", err } - presignClient := s3.NewPresignClient(client) - result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ - Bucket: &s3Cfg.Bucket, - Key: &record.S3Key, - }, s3.WithPresignExpires(1*time.Hour)) + url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour) if err != nil { return "", fmt.Errorf("presign url: %w", err) } - return result.URL, nil + return url, nil } // ─── 内部方法 ─── @@ -574,63 +594,44 @@ func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID strin func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) { raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config) if err != nil || raw == "" { - return nil, nil + return nil, nil //nolint:nilnil // no config is a valid state } var cfg BackupS3Config if err := json.Unmarshal([]byte(raw), &cfg); err != nil { - return nil, nil + return nil, ErrBackupS3ConfigCorrupt + } + // 解密 SecretAccessKey + if cfg.SecretAccessKey != "" { + decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey) + if err != nil { + // 兼容未加密的旧数据:如果解密失败,保持原值 + logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err) + } else { + cfg.SecretAccessKey = decrypted + } } return &cfg, nil } -func (s *BackupService) buildS3Client(ctx context.Context, cfg *BackupS3Config) (*s3.Client, error) { - region := cfg.Region - if region == "" { - region = "auto" // Cloudflare R2 默认 region - } - - awsCfg, err := awsconfig.LoadDefaultConfig(ctx, - awsconfig.WithRegion(region), - awsconfig.WithCredentialsProvider( - credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""), - ), - ) - if err != nil { - return nil, fmt.Errorf("load aws config: %w", err) - } - - client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { - if cfg.Endpoint != "" { - o.BaseEndpoint = &cfg.Endpoint - } - if cfg.ForcePathStyle { - o.UsePathStyle = true - } - o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware) - o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired - }) - return client, nil -} - -func (s *BackupService) getOrCreateS3Client(ctx context.Context, cfg *BackupS3Config) (*s3.Client, error) { +func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) { s.mu.Lock() defer s.mu.Unlock() - if s.s3Client != nil && s.s3Cfg != nil { - return s.s3Client, nil + if s.store != nil && s.s3Cfg != nil { + return s.store, nil } if cfg == nil { return nil, ErrBackupS3NotConfigured } - client, err := s.buildS3Client(ctx, cfg) + store, err := s.storeFactory(ctx, cfg) if err != nil { return nil, err } - s.s3Client = client + s.store = store s.s3Cfg = cfg - return client, nil + return store, nil } func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string { @@ -641,76 +642,34 @@ func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName) } -func (s *BackupService) pgDump(ctx context.Context) ([]byte, error) { - args := []string{ - "-h", s.dbCfg.Host, - "-p", fmt.Sprintf("%d", s.dbCfg.Port), - "-U", s.dbCfg.User, - "-d", s.dbCfg.DBName, - "--no-owner", - "--no-acl", - "--clean", - "--if-exists", - } - - cmd := exec.CommandContext(ctx, "pg_dump", args...) - if s.dbCfg.Password != "" { - cmd.Env = append(cmd.Environ(), "PGPASSWORD="+s.dbCfg.Password) - } - if s.dbCfg.SSLMode != "" { - cmd.Env = append(cmd.Environ(), "PGSSLMODE="+s.dbCfg.SSLMode) - } - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("%v: %s", err, stderr.String()) - } - return stdout.Bytes(), nil -} - -func (s *BackupService) pgRestore(ctx context.Context, sqlData []byte) error { - args := []string{ - "-h", s.dbCfg.Host, - "-p", fmt.Sprintf("%d", s.dbCfg.Port), - "-U", s.dbCfg.User, - "-d", s.dbCfg.DBName, - "--single-transaction", - } - - cmd := exec.CommandContext(ctx, "psql", args...) - if s.dbCfg.Password != "" { - cmd.Env = append(cmd.Environ(), "PGPASSWORD="+s.dbCfg.Password) - } - if s.dbCfg.SSLMode != "" { - cmd.Env = append(cmd.Environ(), "PGSSLMODE="+s.dbCfg.SSLMode) - } - - cmd.Stdin = bytes.NewReader(sqlData) - var stderr bytes.Buffer - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - return fmt.Errorf("%v: %s", err, stderr.String()) - } - return nil -} - +// loadRecords 加载备份记录,区分"无数据"和"数据损坏" func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) { + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + return s.loadRecordsLocked(ctx) +} + +// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录 +func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) { raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords) if err != nil || raw == "" { - return nil, nil + return nil, nil //nolint:nilnil // no records is a valid state } var records []BackupRecord if err := json.Unmarshal([]byte(raw), &records); err != nil { - return nil, nil + return nil, ErrBackupRecordsCorrupt } return records, nil } func (s *BackupService) saveRecords(ctx context.Context, records []BackupRecord) error { + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + return s.saveRecordsLocked(ctx, records) +} + +// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录 +func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error { data, err := json.Marshal(records) if err != nil { return err @@ -718,8 +677,12 @@ func (s *BackupService) saveRecords(ctx context.Context, records []BackupRecord) return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data)) } +// saveRecord 保存单条记录(带互斥锁保护) func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error { - records, _ := s.loadRecords(ctx) + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + + records, _ := s.loadRecordsLocked(ctx) // 更新已有记录或追加 found := false @@ -739,7 +702,7 @@ func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) er records = records[len(records)-maxBackupRecords:] } - return s.saveRecords(ctx, records) + return s.saveRecordsLocked(ctx, records) } func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error { @@ -747,7 +710,10 @@ func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupS return nil } - records, err := s.loadRecords(ctx) + s.recordsMu.Lock() + defer s.recordsMu.Unlock() + + records, err := s.loadRecordsLocked(ctx) if err != nil { return err } @@ -792,7 +758,7 @@ func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupS if len(toDelete) > 0 { logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete)) - return s.saveRecords(ctx, toKeep) + return s.saveRecordsLocked(ctx, toKeep) } return nil } @@ -802,13 +768,9 @@ func (s *BackupService) deleteS3Object(ctx context.Context, key string) error { if err != nil || s3Cfg == nil { return nil } - client, err := s.getOrCreateS3Client(ctx, s3Cfg) + objectStore, err := s.getOrCreateStore(ctx, s3Cfg) if err != nil { return err } - _, err = client.DeleteObject(ctx, &s3.DeleteObjectInput{ - Bucket: &s3Cfg.Bucket, - Key: &key, - }) - return err + return objectStore.Delete(ctx, key) } diff --git a/backend/internal/service/backup_service_test.go b/backend/internal/service/backup_service_test.go new file mode 100644 index 00000000..e752997c --- /dev/null +++ b/backend/internal/service/backup_service_test.go @@ -0,0 +1,528 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// ─── Mocks ─── + +type mockSettingRepo struct { + mu sync.Mutex + data map[string]string +} + +func newMockSettingRepo() *mockSettingRepo { + return &mockSettingRepo{data: make(map[string]string)} +} + +func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.data[key] + if !ok { + return nil, ErrSettingNotFound + } + return &Setting{Key: key, Value: v}, nil +} + +func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.data[key] + if !ok { + return "", nil + } + return v, nil +} + +func (m *mockSettingRepo) Set(_ context.Context, key, value string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.data[key] = value + return nil +} + +func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + result := make(map[string]string) + for _, k := range keys { + if v, ok := m.data[k]; ok { + result[k] = v + } + } + return result, nil +} + +func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error { + m.mu.Lock() + defer m.mu.Unlock() + for k, v := range settings { + m.data[k] = v + } + return nil +} + +func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + result := make(map[string]string, len(m.data)) + for k, v := range m.data { + result[k] = v + } + return result, nil +} + +func (m *mockSettingRepo) Delete(_ context.Context, key string) error { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.data, key) + return nil +} + +// plainEncryptor 仅做 base64-like 包装,用于测试 +type plainEncryptor struct{} + +func (e *plainEncryptor) Encrypt(plaintext string) (string, error) { + return "ENC:" + plaintext, nil +} + +func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) { + if strings.HasPrefix(ciphertext, "ENC:") { + return strings.TrimPrefix(ciphertext, "ENC:"), nil + } + return ciphertext, fmt.Errorf("not encrypted") +} + +type mockDumper struct { + dumpData []byte + dumpErr error + restored []byte + restErr error +} + +func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) { + if m.dumpErr != nil { + return nil, m.dumpErr + } + return io.NopCloser(bytes.NewReader(m.dumpData)), nil +} + +func (m *mockDumper) Restore(_ context.Context, data io.Reader) error { + if m.restErr != nil { + return m.restErr + } + d, err := io.ReadAll(data) + if err != nil { + return err + } + m.restored = d + return nil +} + +type mockObjectStore struct { + objects map[string][]byte + mu sync.Mutex +} + +func newMockObjectStore() *mockObjectStore { + return &mockObjectStore{objects: make(map[string][]byte)} +} + +func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) { + data, err := io.ReadAll(body) + if err != nil { + return 0, err + } + m.mu.Lock() + m.objects[key] = data + m.mu.Unlock() + return int64(len(data)), nil +} + +func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) { + m.mu.Lock() + data, ok := m.objects[key] + m.mu.Unlock() + if !ok { + return nil, fmt.Errorf("not found: %s", key) + } + return io.NopCloser(bytes.NewReader(data)), nil +} + +func (m *mockObjectStore) Delete(_ context.Context, key string) error { + m.mu.Lock() + delete(m.objects, key) + m.mu.Unlock() + return nil +} + +func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) { + return "https://presigned.example.com/" + key, nil +} + +func (m *mockObjectStore) HeadBucket(_ context.Context) error { + return nil +} + +func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService { + cfg := &config.Config{ + Database: config.DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "test", + DBName: "testdb", + }, + } + factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) { + return store, nil + } + return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper) +} + +func seedS3Config(t *testing.T, repo *mockSettingRepo) { + t.Helper() + cfg := BackupS3Config{ + Bucket: "test-bucket", + AccessKeyID: "AKID", + SecretAccessKey: "ENC:secret123", + Prefix: "backups", + } + data, _ := json.Marshal(cfg) + require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data))) +} + +// ─── Tests ─── + +func TestBackupService_S3ConfigEncryption(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + // 保存配置 -> SecretAccessKey 应被加密 + _, err := svc.UpdateS3Config(context.Background(), BackupS3Config{ + Bucket: "my-bucket", + AccessKeyID: "AKID", + SecretAccessKey: "my-secret", + Prefix: "backups", + }) + require.NoError(t, err) + + // 直接读取数据库中存储的值,应该是加密后的 + raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config) + var stored BackupS3Config + require.NoError(t, json.Unmarshal([]byte(raw), &stored)) + require.Equal(t, "ENC:my-secret", stored.SecretAccessKey) + + // 通过 GetS3Config 获取应该脱敏 + cfg, err := svc.GetS3Config(context.Background()) + require.NoError(t, err) + require.Empty(t, cfg.SecretAccessKey) + require.Equal(t, "my-bucket", cfg.Bucket) + + // loadS3Config 内部应解密 + internal, err := svc.loadS3Config(context.Background()) + require.NoError(t, err) + require.Equal(t, "my-secret", internal.SecretAccessKey) +} + +func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + // 先保存一个有 secret 的配置 + _, err := svc.UpdateS3Config(context.Background(), BackupS3Config{ + Bucket: "my-bucket", + AccessKeyID: "AKID", + SecretAccessKey: "original-secret", + }) + require.NoError(t, err) + + // 再更新时不提供 secret,应保留原值 + _, err = svc.UpdateS3Config(context.Background(), BackupS3Config{ + Bucket: "my-bucket", + AccessKeyID: "AKID-NEW", + }) + require.NoError(t, err) + + internal, err := svc.loadS3Config(context.Background()) + require.NoError(t, err) + require.Equal(t, "original-secret", internal.SecretAccessKey) + require.Equal(t, "AKID-NEW", internal.AccessKeyID) +} + +func TestBackupService_SaveRecordConcurrency(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + var wg sync.WaitGroup + n := 20 + wg.Add(n) + for i := 0; i < n; i++ { + go func(idx int) { + defer wg.Done() + record := &BackupRecord{ + ID: fmt.Sprintf("rec-%d", idx), + Status: "completed", + StartedAt: time.Now().Format(time.RFC3339), + } + _ = svc.saveRecord(context.Background(), record) + }(i) + } + wg.Wait() + + records, err := svc.loadRecords(context.Background()) + require.NoError(t, err) + require.Len(t, records, n) +} + +func TestBackupService_LoadRecords_Empty(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + records, err := svc.loadRecords(context.Background()) + require.NoError(t, err) + require.Nil(t, records) // 无数据时返回 nil +} + +func TestBackupService_LoadRecords_Corrupted(t *testing.T) { + repo := newMockSettingRepo() + _ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{") + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + records, err := svc.loadRecords(context.Background()) + require.Error(t, err) // 损坏数据应返回错误 + require.Nil(t, records) +} + +func TestBackupService_CreateBackup_Streaming(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n" + dumper := &mockDumper{dumpData: []byte(dumpContent)} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + require.Equal(t, "completed", record.Status) + require.Greater(t, record.SizeBytes, int64(0)) + require.NotEmpty(t, record.S3Key) + + // 验证 S3 上确实有文件 + store.mu.Lock() + require.Len(t, store.objects, 1) + store.mu.Unlock() +} + +func TestBackupService_CreateBackup_DumpFailure(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.Error(t, err) + require.Equal(t, "failed", record.Status) + require.Contains(t, record.ErrorMsg, "pg_dump") +} + +func TestBackupService_CreateBackup_NoS3Config(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + _, err := svc.CreateBackup(context.Background(), "manual", 14) + require.ErrorIs(t, err, ErrBackupS3NotConfigured) +} + +func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + // 使用一个慢速 dumper 来模拟正在进行的备份 + dumper := &mockDumper{dumpData: []byte("data")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + // 手动设置 backingUp 标志 + svc.mu.Lock() + svc.backingUp = true + svc.mu.Unlock() + + _, err := svc.CreateBackup(context.Background(), "manual", 14) + require.ErrorIs(t, err, ErrBackupInProgress) +} + +func TestBackupService_RestoreBackup_Streaming(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n" + dumper := &mockDumper{dumpData: []byte(dumpContent)} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + // 先创建一个备份 + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + // 恢复 + err = svc.RestoreBackup(context.Background(), record.ID) + require.NoError(t, err) + + // 验证 psql 收到的数据是否与原始 dump 内容一致 + require.Equal(t, dumpContent, string(dumper.restored)) +} + +func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + // 手动插入一条 failed 记录 + _ = svc.saveRecord(context.Background(), &BackupRecord{ + ID: "fail-1", + Status: "failed", + }) + + err := svc.RestoreBackup(context.Background(), "fail-1") + require.Error(t, err) +} + +func TestBackupService_DeleteBackup(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumpContent := "data" + dumper := &mockDumper{dumpData: []byte(dumpContent)} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + // S3 中应有文件 + store.mu.Lock() + require.Len(t, store.objects, 1) + store.mu.Unlock() + + // 删除 + err = svc.DeleteBackup(context.Background(), record.ID) + require.NoError(t, err) + + // S3 中文件应被删除 + store.mu.Lock() + require.Len(t, store.objects, 0) + store.mu.Unlock() + + // 记录应不存在 + _, err = svc.GetBackupRecord(context.Background(), record.ID) + require.ErrorIs(t, err, ErrBackupNotFound) +} + +func TestBackupService_GetDownloadURL(t *testing.T) { + repo := newMockSettingRepo() + seedS3Config(t, repo) + + dumper := &mockDumper{dumpData: []byte("data")} + store := newMockObjectStore() + svc := newTestBackupService(repo, dumper, store) + + record, err := svc.CreateBackup(context.Background(), "manual", 14) + require.NoError(t, err) + + url, err := svc.GetBackupDownloadURL(context.Background(), record.ID) + require.NoError(t, err) + require.Contains(t, url, "https://presigned.example.com/") +} + +func TestBackupService_ListBackups_Sorted(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + now := time.Now() + for i := 0; i < 3; i++ { + _ = svc.saveRecord(context.Background(), &BackupRecord{ + ID: fmt.Sprintf("rec-%d", i), + Status: "completed", + StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339), + }) + } + + records, err := svc.ListBackups(context.Background()) + require.NoError(t, err) + require.Len(t, records, 3) + // 最新在前 + require.Equal(t, "rec-2", records[0].ID) + require.Equal(t, "rec-0", records[2].ID) +} + +func TestBackupService_TestS3Connection(t *testing.T) { + repo := newMockSettingRepo() + store := newMockObjectStore() + svc := newTestBackupService(repo, &mockDumper{}, store) + + err := svc.TestS3Connection(context.Background(), BackupS3Config{ + Bucket: "test", + AccessKeyID: "ak", + SecretAccessKey: "sk", + }) + require.NoError(t, err) +} + +func TestBackupService_TestS3Connection_Incomplete(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + err := svc.TestS3Connection(context.Background(), BackupS3Config{ + Bucket: "test", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "incomplete") +} + +func TestBackupService_Schedule_CronValidation(t *testing.T) { + repo := newMockSettingRepo() + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + svc.cronSched = nil // 未初始化 cron + + // 启用但 cron 为空 + _, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{ + Enabled: true, + CronExpr: "", + }) + require.Error(t, err) + + // 无效的 cron 表达式 + _, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{ + Enabled: true, + CronExpr: "invalid", + }) + require.Error(t, err) +} + +func TestBackupService_LoadS3Config_Corrupted(t *testing.T) { + repo := newMockSettingRepo() + _ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!") + svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore()) + + cfg, err := svc.loadS3Config(context.Background()) + require.Error(t, err) + require.Nil(t, cfg) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 4ae06731..3d2d5d68 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -323,8 +323,14 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC } // ProvideBackupService creates and starts BackupService -func ProvideBackupService(settingRepo SettingRepository, cfg *config.Config) *BackupService { - svc := NewBackupService(settingRepo, cfg) +func ProvideBackupService( + settingRepo SettingRepository, + cfg *config.Config, + encryptor SecretEncryptor, + storeFactory BackupObjectStoreFactory, + dumper DBDumper, +) *BackupService { + svc := NewBackupService(settingRepo, cfg, encryptor, storeFactory, dumper) svc.Start() return svc } diff --git a/frontend/src/api/admin/backup.ts b/frontend/src/api/admin/backup.ts index eff70492..d349c862 100644 --- a/frontend/src/api/admin/backup.ts +++ b/frontend/src/api/admin/backup.ts @@ -93,8 +93,8 @@ export async function getDownloadURL(id: string): Promise<{ url: string }> { } // Restore -export async function restoreBackup(id: string): Promise { - await apiClient.post(`/admin/backups/${id}/restore`, {}, { timeout: 600000 }) +export async function restoreBackup(id: string, password: string): Promise { + await apiClient.post(`/admin/backups/${id}/restore`, { password }, { timeout: 600000 }) } export const backupAPI = { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 40d1a8eb..9f609d72 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1034,6 +1034,7 @@ export default { download: 'Download', restore: 'Restore', restoreConfirm: 'Are you sure you want to restore from this backup? This will overwrite the current database!', + restorePasswordPrompt: 'Please enter your admin password to confirm the restore operation', restoreSuccess: 'Database restored successfully', deleteConfirm: 'Are you sure you want to delete this backup?', deleted: 'Backup deleted' diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index b276f059..d443783e 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1056,6 +1056,7 @@ export default { download: '下载', restore: '恢复', restoreConfirm: '确定要从此备份恢复吗?这将覆盖当前数据库!', + restorePasswordPrompt: '请输入管理员密码以确认恢复操作', restoreSuccess: '数据库恢复成功', deleteConfirm: '确定要删除此备份吗?', deleted: '备份已删除' diff --git a/frontend/src/views/admin/BackupView.vue b/frontend/src/views/admin/BackupView.vue index 2c54f365..ae10f42c 100644 --- a/frontend/src/views/admin/BackupView.vue +++ b/frontend/src/views/admin/BackupView.vue @@ -440,9 +440,11 @@ async function downloadBackup(id: string) { async function restoreBackup(id: string) { if (!window.confirm(t('admin.backup.actions.restoreConfirm'))) return + const password = window.prompt(t('admin.backup.actions.restorePasswordPrompt')) + if (!password) return restoringId.value = id try { - await adminAPI.backup.restoreBackup(id) + await adminAPI.backup.restoreBackup(id, password) appStore.showSuccess(t('admin.backup.actions.restoreSuccess')) } catch (error) { appStore.showError((error as { message?: string })?.message || t('errors.networkError'))