feat(backup): 备份/恢复异步化,解决 504 超时

POST /backups 和 POST /backups/:id/restore 改为异步:立即返回 HTTP 202,
后台 goroutine 独立执行 pg_dump → gzip → S3 上传,前端每 2s 轮询状态。

后端:
- 新增 StartBackup/StartRestore 方法,后台 goroutine 不依赖 HTTP 连接
- Graceful shutdown 等待活跃操作完成,启动时清理孤立 running 记录
- BackupRecord 新增 progress/restore_status 字段支持进度和恢复状态追踪

前端:
- 创建备份/恢复后轮询 GET /backups/:id 直到完成或失败
- 标签页切换暂停/恢复轮询,组件卸载清理定时器
- 正确处理 409(备份进行中)和轮询超时

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
QTom
2026-03-16 20:03:08 +08:00
parent f42c8f2abe
commit c1fab7f8d8
9 changed files with 780 additions and 68 deletions

View File

@@ -98,12 +98,12 @@ func (h *BackupHandler) CreateBackup(c *gin.Context) {
expireDays = *req.ExpireDays
}
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
response.Accepted(c, record)
}
func (h *BackupHandler) ListBackups(c *gin.Context) {
@@ -196,9 +196,10 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) {
return
}
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"restored": true})
response.Accepted(c, record)
}

View File

@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
})
}
// Accepted writes an HTTP 202 response for requests that have been accepted
// for asynchronous processing; data carries the in-progress resource.
func Accepted(c *gin.Context, data any) {
	resp := Response{
		Code:    0,
		Message: "accepted",
		Data:    data,
	}
	c.JSON(http.StatusAccepted, resp)
}
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{

View File

@@ -57,6 +57,7 @@ func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
// 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject
data, err := io.ReadAll(body)
if err != nil {
return 0, fmt.Errorf("read body: %w", err)

View File

@@ -4,11 +4,13 @@ import (
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -84,17 +86,21 @@ type BackupScheduleConfig struct {
// BackupRecord 备份记录
type BackupRecord struct {
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
Progress string `json:"progress,omitempty"` // "dumping", "uploading", ""
RestoreStatus string `json:"restore_status,omitempty"` // "", "running", "completed", "failed"
RestoreError string `json:"restore_error,omitempty"`
RestoredAt string `json:"restored_at,omitempty"`
}
// BackupService 数据库备份恢复服务
@@ -105,17 +111,24 @@ type BackupService struct {
storeFactory BackupObjectStoreFactory
dumper DBDumper
mu sync.Mutex
store BackupObjectStore
s3Cfg *BackupS3Config
opMu sync.Mutex // 保护 backingUp/restoring 标志
backingUp bool
restoring bool
storeMu sync.Mutex // 保护 store/s3Cfg 缓存
store BackupObjectStore
s3Cfg *BackupS3Config
recordsMu sync.Mutex // 保护 records 的 load/save 操作
cronMu sync.Mutex
cronSched *cron.Cron
cronEntryID cron.EntryID
wg sync.WaitGroup // 追踪活跃的备份/恢复 goroutine
shuttingDown atomic.Bool // 阻止新备份启动
bgCtx context.Context // 所有后台操作的 parent context
bgCancel context.CancelFunc // 取消所有活跃后台操作
}
func NewBackupService(
@@ -125,20 +138,26 @@ func NewBackupService(
storeFactory BackupObjectStoreFactory,
dumper DBDumper,
) *BackupService {
bgCtx, bgCancel := context.WithCancel(context.Background())
return &BackupService{
settingRepo: settingRepo,
dbCfg: &cfg.Database,
encryptor: encryptor,
storeFactory: storeFactory,
dumper: dumper,
bgCtx: bgCtx,
bgCancel: bgCancel,
}
}
// Start 启动定时备份调度器
// Start 启动定时备份调度器并清理孤立记录
func (s *BackupService) Start() {
s.cronSched = cron.New()
s.cronSched.Start()
// 清理重启后孤立的 running 记录
s.recoverStaleRecords()
// 加载已有的定时配置
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -154,13 +173,65 @@ func (s *BackupService) Start() {
}
}
// Stop 停止定时备份
// recoverStaleRecords marks any backup/restore records left in "running"
// state by a previous process (crash or restart) as failed, so the UI does
// not poll a record that will never finish.
func (s *BackupService) recoverStaleRecords() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	records, err := s.loadRecords(ctx)
	if err != nil {
		// Best effort: nothing to recover if records cannot be loaded.
		return
	}

	for i := range records {
		rec := &records[i]
		if rec.Status == "running" {
			rec.Status = "failed"
			rec.ErrorMsg = "interrupted by server restart"
			rec.Progress = ""
			rec.FinishedAt = time.Now().Format(time.RFC3339)
			_ = s.saveRecord(ctx, rec)
			logger.LegacyPrintf("service.backup", "[Backup] recovered stale running record: %s", rec.ID)
		}
		if rec.RestoreStatus == "running" {
			rec.RestoreStatus = "failed"
			rec.RestoreError = "interrupted by server restart"
			_ = s.saveRecord(ctx, rec)
			logger.LegacyPrintf("service.backup", "[Backup] recovered stale restoring record: %s", rec.ID)
		}
	}
}
// Stop halts the cron scheduler, then waits for any active backup/restore
// goroutines to finish (bounded at 5 minutes) before returning. After the
// deadline, in-flight operations are cancelled via the background context
// and given a short grace period to clean up.
func (s *BackupService) Stop() {
	// Refuse new backups/restores from this point on.
	s.shuttingDown.Store(true)

	// Stop the scheduler. BUG FIX: the previous version both deferred
	// cronMu.Unlock and unlocked explicitly, which panics with
	// "sync: unlock of unlocked mutex". Unlock exactly once, explicitly,
	// so the mutex is not held while waiting below.
	s.cronMu.Lock()
	if s.cronSched != nil {
		s.cronSched.Stop()
	}
	s.cronMu.Unlock()

	// Wait for active backup/restore operations (at most 5 minutes).
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		logger.LegacyPrintf("service.backup", "[Backup] all active operations finished")
	case <-time.After(5 * time.Minute):
		logger.LegacyPrintf("service.backup", "[Backup] shutdown timeout after 5min, cancelling active operations")
		if s.bgCancel != nil {
			s.bgCancel() // cancel every in-flight background operation
		}
		// Give goroutines time to observe cancellation and finish cleanup.
		select {
		case <-done:
			logger.LegacyPrintf("service.backup", "[Backup] active operations cancelled and cleaned up")
		case <-time.After(10 * time.Second):
			logger.LegacyPrintf("service.backup", "[Backup] goroutine cleanup timed out")
		}
	}
}
// ─── S3 配置管理 ───
@@ -203,10 +274,10 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config)
}
// 清除缓存的 S3 客户端
s.mu.Lock()
s.storeMu.Lock()
s.store = nil
s.s3Cfg = nil
s.mu.Unlock()
s.storeMu.Unlock()
cfg.SecretAccessKey = ""
return &cfg, nil
@@ -314,7 +385,10 @@ func (s *BackupService) removeCronSchedule() {
}
func (s *BackupService) runScheduledBackup() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
s.wg.Add(1)
defer s.wg.Done()
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
defer cancel()
// 读取定时备份配置中的过期天数
@@ -327,7 +401,11 @@ func (s *BackupService) runScheduledBackup() {
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
if errors.Is(err, ErrBackupInProgress) {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份跳过: 已有备份正在进行中")
} else {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
}
return
}
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
@@ -346,17 +424,21 @@ func (s *BackupService) runScheduledBackup() {
// CreateBackup 创建全量数据库备份并上传到 S3流式处理
// expireDays: 备份过期天数0=永不过期默认14天
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
s.mu.Lock()
if s.shuttingDown.Load() {
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
}
s.opMu.Lock()
if s.backingUp {
s.mu.Unlock()
s.opMu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.mu.Unlock()
s.opMu.Unlock()
defer func() {
s.mu.Lock()
s.opMu.Lock()
s.backingUp = false
s.mu.Unlock()
s.opMu.Unlock()
}()
s3Cfg, err := s.loadS3Config(ctx)
@@ -405,36 +487,47 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
pr, pw := io.Pipe()
var gzipErr error
gzipDone := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
}
}()
gzWriter := gzip.NewWriter(pw)
_, gzipErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
var gzErr error
_, gzErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if gzipErr != nil {
_ = pw.CloseWithError(gzipErr)
if gzErr != nil {
_ = pw.CloseWithError(gzErr)
} else {
_ = pw.Close()
}
gzipDone <- gzErr
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
if err != nil {
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzipErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
if gzErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
}
record.ErrorMsg = errMsg
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("backup upload: %w", err)
}
<-gzipDone // 确保 gzip goroutine 已退出
record.SizeBytes = sizeBytes
record.Status = "completed"
@@ -446,19 +539,187 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
return record, nil
}
// StartBackup kicks off a full backup asynchronously and immediately returns
// a record in "running" state. The pg_dump → gzip → S3 pipeline runs in a
// background goroutine tied to s.bgCtx, so it is independent of the HTTP
// request lifetime; callers poll GetBackupRecord for progress.
// Returns ErrBackupInProgress when a backup is already active,
// ErrBackupS3NotConfigured when no S3 target is set, and a
// service-unavailable error during shutdown.
func (s *BackupService) StartBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
	if s.shuttingDown.Load() {
		return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
	}
	s.opMu.Lock()
	if s.backingUp {
		s.opMu.Unlock()
		return nil, ErrBackupInProgress
	}
	s.backingUp = true
	s.opMu.Unlock()

	// Reset the backingUp flag automatically if initialization fails before
	// the background goroutine is actually launched.
	launched := false
	defer func() {
		if !launched {
			s.opMu.Lock()
			s.backingUp = false
			s.opMu.Unlock()
		}
	}()

	// Load the S3 config and create the store before returning, so the
	// background goroutine is unaffected by config changes made afterwards.
	s3Cfg, err := s.loadS3Config(ctx)
	if err != nil {
		return nil, err
	}
	if s3Cfg == nil || !s3Cfg.IsConfigured() {
		return nil, ErrBackupS3NotConfigured
	}
	objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
	if err != nil {
		return nil, fmt.Errorf("init object store: %w", err)
	}

	now := time.Now()
	backupID := uuid.New().String()[:8]
	fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
	s3Key := s.buildS3Key(s3Cfg, fileName)

	var expiresAt string
	if expireDays > 0 {
		// expireDays == 0 means the backup never expires.
		expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
	}

	record := &BackupRecord{
		ID:          backupID,
		Status:      "running",
		BackupType:  "postgres",
		FileName:    fileName,
		S3Key:       s3Key,
		TriggeredBy: triggeredBy,
		StartedAt:   now.Format(time.RFC3339),
		ExpiresAt:   expiresAt,
		Progress:    "pending",
	}
	if err := s.saveRecord(ctx, record); err != nil {
		return nil, fmt.Errorf("save initial record: %w", err)
	}
	launched = true

	// Copy the record before starting the goroutine to avoid a data race:
	// the goroutine keeps mutating *record while the caller reads the copy.
	result := *record

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		defer func() {
			s.opMu.Lock()
			s.backingUp = false
			s.opMu.Unlock()
		}()
		defer func() {
			// Persist a failed status on panic so the record never stays
			// "running" forever; use a fresh context because the request
			// context is long gone.
			if r := recover(); r != nil {
				logger.LegacyPrintf("service.backup", "[Backup] panic recovered: %v", r)
				record.Status = "failed"
				record.ErrorMsg = fmt.Sprintf("internal panic: %v", r)
				record.Progress = ""
				record.FinishedAt = time.Now().Format(time.RFC3339)
				_ = s.saveRecord(context.Background(), record)
			}
		}()
		s.executeBackup(record, objectStore)
	}()

	return &result, nil
}
// executeBackup runs the backup pipeline (pg_dump → gzip → S3 upload) in the
// background, independent of any HTTP context. Progress and final status are
// persisted on the record as each phase completes. Bounded by a 30-minute
// timeout derived from the service's background context.
func (s *BackupService) executeBackup(record *BackupRecord, objectStore BackupObjectStore) {
	ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
	defer cancel()

	// Phase 1: pg_dump.
	record.Progress = "dumping"
	_ = s.saveRecord(ctx, record)

	dumpReader, err := s.dumper.Dump(ctx)
	if err != nil {
		record.Status = "failed"
		record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
		record.Progress = ""
		record.FinishedAt = time.Now().Format(time.RFC3339)
		// Save with a fresh context: ctx may already be cancelled/expired.
		_ = s.saveRecord(context.Background(), record)
		return
	}

	// Phase 2: gzip + upload, streamed through an io.Pipe so the whole dump
	// is never buffered in memory.
	record.Progress = "uploading"
	_ = s.saveRecord(ctx, record)

	pr, pw := io.Pipe()
	gzipDone := make(chan error, 1)
	go func() {
		defer func() {
			if r := recover(); r != nil {
				pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
				gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
			}
		}()
		gzWriter := gzip.NewWriter(pw)
		var gzErr error
		_, gzErr = io.Copy(gzWriter, dumpReader)
		// Keep the first error; later close errors must not mask it.
		if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
			gzErr = closeErr
		}
		if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
			gzErr = closeErr
		}
		// Propagate the error through the pipe so Upload fails promptly.
		if gzErr != nil {
			_ = pw.CloseWithError(gzErr)
		} else {
			_ = pw.Close()
		}
		gzipDone <- gzErr
	}()

	contentType := "application/gzip"
	sizeBytes, err := objectStore.Upload(ctx, record.S3Key, pr, contentType)
	if err != nil {
		_ = pr.CloseWithError(err) // unblock the gzip goroutine so it cannot hang
		gzErr := <-gzipDone        // wait for the gzip goroutine to exit
		record.Status = "failed"
		// Prefer the gzip/dump error: it is the root cause of a failed upload.
		errMsg := fmt.Sprintf("S3 upload failed: %v", err)
		if gzErr != nil {
			errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
		}
		record.ErrorMsg = errMsg
		record.Progress = ""
		record.FinishedAt = time.Now().Format(time.RFC3339)
		_ = s.saveRecord(context.Background(), record)
		return
	}
	<-gzipDone // make sure the gzip goroutine has exited

	record.SizeBytes = sizeBytes
	record.Status = "completed"
	record.Progress = ""
	record.FinishedAt = time.Now().Format(time.RFC3339)
	if err := s.saveRecord(context.Background(), record); err != nil {
		logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
	}
}
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
s.mu.Lock()
s.opMu.Lock()
if s.restoring {
s.mu.Unlock()
s.opMu.Unlock()
return ErrRestoreInProgress
}
s.restoring = true
s.mu.Unlock()
s.opMu.Unlock()
defer func() {
s.mu.Lock()
s.opMu.Lock()
s.restoring = false
s.mu.Unlock()
s.opMu.Unlock()
}()
record, err := s.GetBackupRecord(ctx, backupID)
@@ -500,6 +761,112 @@ func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) erro
return nil
}
// StartRestore kicks off a restore asynchronously and immediately returns the
// backup record with RestoreStatus set to "running". The download → gunzip →
// pg restore pipeline runs in a background goroutine tied to s.bgCtx; callers
// poll GetBackupRecord for the outcome. Returns ErrRestoreInProgress when a
// restore is already active and a service-unavailable error during shutdown.
func (s *BackupService) StartRestore(ctx context.Context, backupID string) (*BackupRecord, error) {
	if s.shuttingDown.Load() {
		return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
	}
	s.opMu.Lock()
	if s.restoring {
		s.opMu.Unlock()
		return nil, ErrRestoreInProgress
	}
	s.restoring = true
	s.opMu.Unlock()

	// Reset the restoring flag automatically if initialization fails before
	// the background goroutine is actually launched.
	launched := false
	defer func() {
		if !launched {
			s.opMu.Lock()
			s.restoring = false
			s.opMu.Unlock()
		}
	}()

	record, err := s.GetBackupRecord(ctx, backupID)
	if err != nil {
		return nil, err
	}
	if record.Status != "completed" {
		return nil, infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
	}

	s3Cfg, err := s.loadS3Config(ctx)
	if err != nil {
		return nil, err
	}
	// Guard against a missing/unconfigured S3 target, consistent with
	// StartBackup; previously a nil config could reach getOrCreateStore.
	if s3Cfg == nil || !s3Cfg.IsConfigured() {
		return nil, ErrBackupS3NotConfigured
	}
	objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
	if err != nil {
		return nil, fmt.Errorf("init object store: %w", err)
	}

	record.RestoreStatus = "running"
	_ = s.saveRecord(ctx, record)
	launched = true

	// Copy before launching the goroutine to avoid a data race on *record.
	result := *record

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		defer func() {
			s.opMu.Lock()
			s.restoring = false
			s.opMu.Unlock()
		}()
		defer func() {
			// Persist a failed restore status on panic so the record never
			// stays "running" forever.
			if r := recover(); r != nil {
				logger.LegacyPrintf("service.backup", "[Backup] restore panic recovered: %v", r)
				record.RestoreStatus = "failed"
				record.RestoreError = fmt.Sprintf("internal panic: %v", r)
				_ = s.saveRecord(context.Background(), record)
			}
		}()
		s.executeRestore(record, objectStore)
	}()

	return &result, nil
}
// executeRestore downloads the backup object from S3, decompresses it and
// streams it into the database. Runs in a background goroutine, bounded by a
// 30-minute timeout derived from the service's background context; the final
// restore status is persisted on the record.
func (s *BackupService) executeRestore(record *BackupRecord, objectStore BackupObjectStore) {
	ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
	defer cancel()

	// fail records a failed restore with the given formatted cause.
	// Saved with a fresh context because ctx may already be done.
	fail := func(format string, cause error) {
		record.RestoreStatus = "failed"
		record.RestoreError = fmt.Sprintf(format, cause)
		_ = s.saveRecord(context.Background(), record)
	}

	body, err := objectStore.Download(ctx, record.S3Key)
	if err != nil {
		fail("S3 download failed: %v", err)
		return
	}
	defer func() { _ = body.Close() }()

	gzReader, err := gzip.NewReader(body)
	if err != nil {
		fail("gzip reader: %v", err)
		return
	}
	defer func() { _ = gzReader.Close() }()

	if err := s.dumper.Restore(ctx, gzReader); err != nil {
		fail("pg restore: %v", err)
		return
	}

	record.RestoreStatus = "completed"
	record.RestoredAt = time.Now().Format(time.RFC3339)
	if err := s.saveRecord(context.Background(), record); err != nil {
		logger.LegacyPrintf("service.backup", "[Backup] 保存恢复记录失败: %v", err)
	}
}
// ─── 备份记录管理 ───
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
@@ -614,8 +981,8 @@ func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, erro
}
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.storeMu.Lock()
defer s.storeMu.Unlock()
if s.store != nil && s.s3Cfg != nil {
return s.store, nil

View File

@@ -134,6 +134,30 @@ func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
return nil
}
// blockingDumper 可控延迟的 dumper用于测试异步行为
type blockingDumper struct {
blockCh chan struct{}
data []byte
restErr error
}
func (d *blockingDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
select {
case <-d.blockCh:
case <-ctx.Done():
return nil, ctx.Err()
}
return io.NopCloser(bytes.NewReader(d.data)), nil
}
func (d *blockingDumper) Restore(_ context.Context, data io.Reader) error {
if d.restErr != nil {
return d.restErr
}
_, _ = io.ReadAll(data)
return nil
}
type mockObjectStore struct {
objects map[string][]byte
mu sync.Mutex
@@ -179,7 +203,7 @@ func (m *mockObjectStore) HeadBucket(_ context.Context) error {
return nil
}
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
func newTestBackupService(repo *mockSettingRepo, dumper DBDumper, store *mockObjectStore) *BackupService {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
@@ -361,9 +385,9 @@ func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
svc := newTestBackupService(repo, dumper, store)
// 手动设置 backingUp 标志
svc.mu.Lock()
svc.opMu.Lock()
svc.backingUp = true
svc.mu.Unlock()
svc.opMu.Unlock()
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
@@ -526,3 +550,154 @@ func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
require.Error(t, err)
require.Nil(t, cfg)
}
// ─── Async Backup Tests ───
// TestStartBackup_ReturnsImmediately verifies StartBackup returns a running
// record right away and that the backup then completes in the background.
func TestStartBackup_ReturnsImmediately(t *testing.T) {
	repo := newMockSettingRepo()
	seedS3Config(t, repo)
	dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
	svc := newTestBackupService(repo, dumper, newMockObjectStore())

	rec, err := svc.StartBackup(context.Background(), "manual", 14)
	require.NoError(t, err)
	require.Equal(t, "running", rec.Status)
	require.NotEmpty(t, rec.ID)

	// Unblock the dumper and wait for the background goroutine to finish.
	close(dumper.blockCh)
	svc.wg.Wait()

	// The persisted record must now reflect a successful backup.
	final, err := svc.GetBackupRecord(context.Background(), rec.ID)
	require.NoError(t, err)
	require.Equal(t, "completed", final.Status)
	require.Greater(t, final.SizeBytes, int64(0))
}
// TestStartBackup_ConcurrentBlocked verifies only one backup can run at a
// time: a second StartBackup fails with ErrBackupInProgress.
func TestStartBackup_ConcurrentBlocked(t *testing.T) {
	repo := newMockSettingRepo()
	seedS3Config(t, repo)
	dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
	svc := newTestBackupService(repo, dumper, newMockObjectStore())

	// First launch succeeds and holds the backingUp flag.
	_, err := svc.StartBackup(context.Background(), "manual", 14)
	require.NoError(t, err)

	// Second launch must be rejected while the first is still running.
	_, err = svc.StartBackup(context.Background(), "manual", 14)
	require.ErrorIs(t, err, ErrBackupInProgress)

	close(dumper.blockCh)
	svc.wg.Wait()
}
// TestStartBackup_ShuttingDown verifies StartBackup is refused once the
// service has begun shutting down.
func TestStartBackup_ShuttingDown(t *testing.T) {
	repo := newMockSettingRepo()
	seedS3Config(t, repo)
	service := newTestBackupService(repo, &mockDumper{dumpData: []byte("data")}, newMockObjectStore())

	// Simulate an in-progress graceful shutdown.
	service.shuttingDown.Store(true)

	_, err := service.StartBackup(context.Background(), "manual", 14)
	require.Error(t, err)
	require.Contains(t, err.Error(), "shutting down")
}
// TestRecoverStaleRecords verifies that records left in "running" state by a
// previous process are marked failed on startup.
func TestRecoverStaleRecords(t *testing.T) {
	repo := newMockSettingRepo()
	svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())

	// Seed one orphaned running backup and one orphaned running restore.
	started := time.Now().Add(-1 * time.Hour).Format(time.RFC3339)
	stale := []BackupRecord{
		{ID: "stale-1", Status: "running", StartedAt: started},
		{ID: "stale-2", Status: "completed", RestoreStatus: "running", StartedAt: started},
	}
	for i := range stale {
		_ = svc.saveRecord(context.Background(), &stale[i])
	}

	svc.recoverStaleRecords()

	r1, _ := svc.GetBackupRecord(context.Background(), "stale-1")
	require.Equal(t, "failed", r1.Status)
	require.Contains(t, r1.ErrorMsg, "server restart")

	r2, _ := svc.GetBackupRecord(context.Background(), "stale-2")
	require.Equal(t, "failed", r2.RestoreStatus)
	require.Contains(t, r2.RestoreError, "server restart")
}
// TestGracefulShutdown verifies that Stop blocks until an in-flight backup
// finishes, then returns promptly once the backup is released.
func TestGracefulShutdown(t *testing.T) {
	repo := newMockSettingRepo()
	seedS3Config(t, repo)
	dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
	store := newMockObjectStore()
	svc := newTestBackupService(repo, dumper, store)

	_, err := svc.StartBackup(context.Background(), "manual", 14)
	require.NoError(t, err)

	// Stop should wait for the active backup to complete.
	done := make(chan struct{})
	go func() {
		svc.Stop()
		close(done)
	}()

	// Briefly wait to confirm Stop is still blocked on the active backup.
	select {
	case <-done:
		t.Fatal("Stop returned before backup finished")
	case <-time.After(100 * time.Millisecond):
		// Expected: Stop is still waiting.
	}

	// Unblock the backup.
	close(dumper.blockCh)

	// Now Stop should return.
	select {
	case <-done:
		// Expected.
	case <-time.After(5 * time.Second):
		t.Fatal("Stop did not return after backup finished")
	}
}
// TestStartRestore_Async verifies StartRestore returns a record whose restore
// is "running" and that the restore then completes in the background.
func TestStartRestore_Async(t *testing.T) {
	repo := newMockSettingRepo()
	seedS3Config(t, repo)
	dumper := &mockDumper{dumpData: []byte("-- PostgreSQL dump\nCREATE TABLE test (id int);\n")}
	svc := newTestBackupService(repo, dumper, newMockObjectStore())

	// Produce a completed backup to restore from (synchronous path).
	backup, err := svc.CreateBackup(context.Background(), "manual", 14)
	require.NoError(t, err)

	// Launch the asynchronous restore.
	restored, err := svc.StartRestore(context.Background(), backup.ID)
	require.NoError(t, err)
	require.Equal(t, "running", restored.RestoreStatus)

	svc.wg.Wait()

	// The persisted record must now reflect a successful restore.
	final, err := svc.GetBackupRecord(context.Background(), backup.ID)
	require.NoError(t, err)
	require.Equal(t, "completed", final.RestoreStatus)
}