mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-01 11:50:44 +08:00
fix: 按 review 意见重构数据库备份服务(安全性 + 架构 + 健壮性)
1. S3 凭证加密存储:使用 SecretEncryptor (AES-256-GCM) 加密 SecretAccessKey, 防止备份文件中泄露 S3 凭证,兼容旧的未加密数据 2. 修复 saveRecord 竞态条件:添加 recordsMu 互斥锁保护 records 的 load/save 3. 恢复操作增加服务端验证:handler 层要求重新输入管理员密码,通过 bcrypt 校验,前端弹出密码输入框 4. pg_dump/psql/S3 操作抽象为接口:定义 DBDumper 和 BackupObjectStore 接口, 实现放入 repository 层,遵循项目依赖注入架构规范 5. 改为流式处理避免大数据库 OOM:备份时 pg_dump stdout -> gzip -> io.Pipe -> S3 upload;恢复时 S3 download -> gzip reader -> psql stdin,不再全量加载 6. loadRecords 区分"无数据"和"数据损坏"场景:JSON 解析失败返回明确错误 7. 添加 18 个核心逻辑单元测试:覆盖加密、并发、流式备份/恢复、错误处理等 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,23 +1,16 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
@@ -39,8 +32,32 @@ var (
|
||||
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
|
||||
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
|
||||
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
|
||||
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
|
||||
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
|
||||
)
|
||||
|
||||
// ─── 接口定义 ───
|
||||
|
||||
// DBDumper abstracts database dump/restore operations
|
||||
type DBDumper interface {
|
||||
Dump(ctx context.Context) (io.ReadCloser, error)
|
||||
Restore(ctx context.Context, data io.Reader) error
|
||||
}
|
||||
|
||||
// BackupObjectStore abstracts object storage for backup files
|
||||
type BackupObjectStore interface {
|
||||
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
|
||||
Download(ctx context.Context, key string) (io.ReadCloser, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
|
||||
HeadBucket(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackupObjectStoreFactory creates an object store from S3 config
|
||||
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
|
||||
|
||||
// ─── 数据模型 ───
|
||||
|
||||
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2)
|
||||
type BackupS3Config struct {
|
||||
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
|
||||
@@ -82,26 +99,39 @@ type BackupRecord struct {
|
||||
|
||||
// BackupService 数据库备份恢复服务
|
||||
type BackupService struct {
|
||||
settingRepo SettingRepository
|
||||
dbCfg *config.DatabaseConfig
|
||||
settingRepo SettingRepository
|
||||
dbCfg *config.DatabaseConfig
|
||||
encryptor SecretEncryptor
|
||||
storeFactory BackupObjectStoreFactory
|
||||
dumper DBDumper
|
||||
|
||||
mu sync.Mutex
|
||||
s3Client *s3.Client
|
||||
store BackupObjectStore
|
||||
s3Cfg *BackupS3Config
|
||||
backingUp bool
|
||||
restoring bool
|
||||
|
||||
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||
|
||||
cronMu sync.Mutex
|
||||
cronSched *cron.Cron
|
||||
cronEntryID cron.EntryID
|
||||
}
|
||||
|
||||
func NewBackupService(settingRepo SettingRepository, cfg *config.Config) *BackupService {
|
||||
svc := &BackupService{
|
||||
settingRepo: settingRepo,
|
||||
dbCfg: &cfg.Database,
|
||||
func NewBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
return &BackupService{
|
||||
settingRepo: settingRepo,
|
||||
dbCfg: &cfg.Database,
|
||||
encryptor: encryptor,
|
||||
storeFactory: storeFactory,
|
||||
dumper: dumper,
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
// Start 启动定时备份调度器
|
||||
@@ -136,17 +166,16 @@ func (s *BackupService) Stop() {
|
||||
// ─── S3 配置管理 ───
|
||||
|
||||
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
|
||||
if err != nil || raw == "" {
|
||||
return &BackupS3Config{}, nil
|
||||
cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var cfg BackupS3Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
if cfg == nil {
|
||||
return &BackupS3Config{}, nil
|
||||
}
|
||||
// 脱敏返回
|
||||
cfg.SecretAccessKey = ""
|
||||
return &cfg, nil
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
|
||||
@@ -156,6 +185,13 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
} else {
|
||||
// 加密 SecretAccessKey
|
||||
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt secret: %w", err)
|
||||
}
|
||||
cfg.SecretAccessKey = encrypted
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
@@ -168,7 +204,7 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config)
|
||||
|
||||
// 清除缓存的 S3 客户端
|
||||
s.mu.Lock()
|
||||
s.s3Client = nil
|
||||
s.store = nil
|
||||
s.s3Cfg = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
@@ -189,17 +225,11 @@ func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config
|
||||
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
|
||||
client, err := s.buildS3Client(ctx, &cfg)
|
||||
store, err := s.storeFactory(ctx, &cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
return store.HeadBucket(ctx)
|
||||
}
|
||||
|
||||
// ─── 定时备份管理 ───
|
||||
@@ -313,7 +343,7 @@ func (s *BackupService) runScheduledBackup() {
|
||||
|
||||
// ─── 备份/恢复核心 ───
|
||||
|
||||
// CreateBackup 创建全量数据库备份并上传到 S3
|
||||
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||
s.mu.Lock()
|
||||
@@ -337,9 +367,9 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
client, err := s.getOrCreateS3Client(ctx, s3Cfg)
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init S3 client: %w", err)
|
||||
return nil, fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
@@ -363,8 +393,8 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
// 执行全量 pg_dump
|
||||
dumpData, err := s.pgDump(ctx)
|
||||
// 流式执行: pg_dump -> gzip -> S3 upload
|
||||
dumpReader, err := s.dumper.Dump(ctx)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||
@@ -373,38 +403,40 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
|
||||
return record, fmt.Errorf("pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// gzip 压缩
|
||||
var compressed bytes.Buffer
|
||||
gzWriter := gzip.NewWriter(&compressed)
|
||||
if _, err := gzWriter.Write(dumpData); err != nil {
|
||||
record.Status = "failed"
|
||||
record.ErrorMsg = fmt.Sprintf("gzip failed: %v", err)
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("gzip: %w", err)
|
||||
}
|
||||
if err := gzWriter.Close(); err != nil {
|
||||
return nil, fmt.Errorf("gzip close: %w", err)
|
||||
}
|
||||
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||
pr, pw := io.Pipe()
|
||||
var gzipErr error
|
||||
go func() {
|
||||
gzWriter := gzip.NewWriter(pw)
|
||||
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
||||
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if gzipErr != nil {
|
||||
_ = pw.CloseWithError(gzipErr)
|
||||
} else {
|
||||
_ = pw.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
record.SizeBytes = int64(compressed.Len())
|
||||
|
||||
// 上传到 S3
|
||||
contentType := "application/gzip"
|
||||
_, err = client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s3Cfg.Bucket,
|
||||
Key: &s3Key,
|
||||
Body: bytes.NewReader(compressed.Bytes()),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
record.ErrorMsg = fmt.Sprintf("S3 upload failed: %v", err)
|
||||
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||
if gzipErr != nil {
|
||||
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
||||
}
|
||||
record.ErrorMsg = errMsg
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("s3 upload: %w", err)
|
||||
return record, fmt.Errorf("backup upload: %w", err)
|
||||
}
|
||||
|
||||
record.SizeBytes = sizeBytes
|
||||
record.Status = "completed"
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
if err := s.saveRecord(ctx, record); err != nil {
|
||||
@@ -414,7 +446,7 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RestoreBackup 从 S3 下载备份并恢复到数据库
|
||||
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||
s.mu.Lock()
|
||||
if s.restoring {
|
||||
@@ -441,35 +473,27 @@ func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) erro
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client, err := s.getOrCreateS3Client(ctx, s3Cfg)
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init S3 client: %w", err)
|
||||
return fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
// 从 S3 下载
|
||||
result, err := client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s3Cfg.Bucket,
|
||||
Key: &record.S3Key,
|
||||
})
|
||||
// 从 S3 流式下载
|
||||
body, err := objectStore.Download(ctx, record.S3Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 download failed: %w", err)
|
||||
}
|
||||
defer func() { _ = result.Body.Close() }()
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
// 解压 gzip
|
||||
gzReader, err := gzip.NewReader(result.Body)
|
||||
// 流式解压 gzip -> psql(不将全部数据加载到内存)
|
||||
gzReader, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer func() { _ = gzReader.Close() }()
|
||||
|
||||
sqlData, err := io.ReadAll(gzReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read backup data: %w", err)
|
||||
}
|
||||
|
||||
// 执行 psql 恢复
|
||||
if err := s.pgRestore(ctx, sqlData); err != nil {
|
||||
// 流式恢复
|
||||
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||
return fmt.Errorf("pg restore: %w", err)
|
||||
}
|
||||
|
||||
@@ -504,7 +528,10 @@ func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*
|
||||
}
|
||||
|
||||
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
|
||||
records, err := s.loadRecords(ctx)
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -526,17 +553,14 @@ func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error
|
||||
if found.S3Key != "" && found.Status == "completed" {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
|
||||
client, err := s.getOrCreateS3Client(ctx, s3Cfg)
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err == nil {
|
||||
_, _ = client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s3Cfg.Bucket,
|
||||
Key: &found.S3Key,
|
||||
})
|
||||
_ = objectStore.Delete(ctx, found.S3Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.saveRecords(ctx, remaining)
|
||||
return s.saveRecordsLocked(ctx, remaining)
|
||||
}
|
||||
|
||||
// GetBackupDownloadURL 获取备份文件预签名下载 URL
|
||||
@@ -553,20 +577,16 @@ func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID strin
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
client, err := s.getOrCreateS3Client(ctx, s3Cfg)
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
presignClient := s3.NewPresignClient(client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s3Cfg.Bucket,
|
||||
Key: &record.S3Key,
|
||||
}, s3.WithPresignExpires(1*time.Hour))
|
||||
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// ─── 内部方法 ───
|
||||
@@ -574,63 +594,44 @@ func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID strin
|
||||
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // no config is a valid state
|
||||
}
|
||||
var cfg BackupS3Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return nil, nil
|
||||
return nil, ErrBackupS3ConfigCorrupt
|
||||
}
|
||||
// 解密 SecretAccessKey
|
||||
if cfg.SecretAccessKey != "" {
|
||||
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
// 兼容未加密的旧数据:如果解密失败,保持原值
|
||||
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
|
||||
} else {
|
||||
cfg.SecretAccessKey = decrypted
|
||||
}
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) buildS3Client(ctx context.Context, cfg *BackupS3Config) (*s3.Client, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) getOrCreateS3Client(ctx context.Context, cfg *BackupS3Config) (*s3.Client, error) {
|
||||
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.s3Client != nil && s.s3Cfg != nil {
|
||||
return s.s3Client, nil
|
||||
if s.store != nil && s.s3Cfg != nil {
|
||||
return s.store, nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
client, err := s.buildS3Client(ctx, cfg)
|
||||
store, err := s.storeFactory(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.s3Client = client
|
||||
s.store = store
|
||||
s.s3Cfg = cfg
|
||||
return client, nil
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
|
||||
@@ -641,76 +642,34 @@ func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string
|
||||
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
|
||||
}
|
||||
|
||||
func (s *BackupService) pgDump(ctx context.Context) ([]byte, error) {
|
||||
args := []string{
|
||||
"-h", s.dbCfg.Host,
|
||||
"-p", fmt.Sprintf("%d", s.dbCfg.Port),
|
||||
"-U", s.dbCfg.User,
|
||||
"-d", s.dbCfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if s.dbCfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+s.dbCfg.Password)
|
||||
}
|
||||
if s.dbCfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+s.dbCfg.SSLMode)
|
||||
}
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("%v: %s", err, stderr.String())
|
||||
}
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
|
||||
func (s *BackupService) pgRestore(ctx context.Context, sqlData []byte) error {
|
||||
args := []string{
|
||||
"-h", s.dbCfg.Host,
|
||||
"-p", fmt.Sprintf("%d", s.dbCfg.Port),
|
||||
"-U", s.dbCfg.User,
|
||||
"-d", s.dbCfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if s.dbCfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+s.dbCfg.Password)
|
||||
}
|
||||
if s.dbCfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+s.dbCfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = bytes.NewReader(sqlData)
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("%v: %s", err, stderr.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
|
||||
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
return s.loadRecordsLocked(ctx)
|
||||
}
|
||||
|
||||
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
|
||||
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // no records is a valid state
|
||||
}
|
||||
var records []BackupRecord
|
||||
if err := json.Unmarshal([]byte(raw), &records); err != nil {
|
||||
return nil, nil
|
||||
return nil, ErrBackupRecordsCorrupt
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) saveRecords(ctx context.Context, records []BackupRecord) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
return s.saveRecordsLocked(ctx, records)
|
||||
}
|
||||
|
||||
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
|
||||
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
|
||||
data, err := json.Marshal(records)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -718,8 +677,12 @@ func (s *BackupService) saveRecords(ctx context.Context, records []BackupRecord)
|
||||
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
|
||||
}
|
||||
|
||||
// saveRecord 保存单条记录(带互斥锁保护)
|
||||
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
|
||||
records, _ := s.loadRecords(ctx)
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, _ := s.loadRecordsLocked(ctx)
|
||||
|
||||
// 更新已有记录或追加
|
||||
found := false
|
||||
@@ -739,7 +702,7 @@ func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) er
|
||||
records = records[len(records)-maxBackupRecords:]
|
||||
}
|
||||
|
||||
return s.saveRecords(ctx, records)
|
||||
return s.saveRecordsLocked(ctx, records)
|
||||
}
|
||||
|
||||
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
|
||||
@@ -747,7 +710,10 @@ func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupS
|
||||
return nil
|
||||
}
|
||||
|
||||
records, err := s.loadRecords(ctx)
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -792,7 +758,7 @@ func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupS
|
||||
|
||||
if len(toDelete) > 0 {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
|
||||
return s.saveRecords(ctx, toKeep)
|
||||
return s.saveRecordsLocked(ctx, toKeep)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -802,13 +768,9 @@ func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
|
||||
if err != nil || s3Cfg == nil {
|
||||
return nil
|
||||
}
|
||||
client, err := s.getOrCreateS3Client(ctx, s3Cfg)
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s3Cfg.Bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
return objectStore.Delete(ctx, key)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user