mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-16 12:54:45 +08:00
1. S3 凭证加密存储:使用 SecretEncryptor (AES-256-GCM) 加密 SecretAccessKey, 防止备份文件中泄露 S3 凭证,兼容旧的未加密数据 2. 修复 saveRecord 竞态条件:添加 recordsMu 互斥锁保护 records 的 load/save 3. 恢复操作增加服务端验证:handler 层要求重新输入管理员密码,通过 bcrypt 校验,前端弹出密码输入框 4. pg_dump/psql/S3 操作抽象为接口:定义 DBDumper 和 BackupObjectStore 接口, 实现放入 repository 层,遵循项目依赖注入架构规范 5. 改为流式处理避免大数据库 OOM:备份时 pg_dump stdout -> gzip -> io.Pipe -> S3 upload;恢复时 S3 download -> gzip reader -> psql stdin,不再全量加载 6. loadRecords 区分"无数据"和"数据损坏"场景:JSON 解析失败返回明确错误 7. 添加 18 个核心逻辑单元测试:覆盖加密、并发、流式备份/恢复、错误处理等 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
117 lines
3.2 KiB
Go
117 lines
3.2 KiB
Go
package repository
|
||
|
||
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"os"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	awsconfig "github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/service/s3"

	"github.com/Wei-Shaw/sub2api/internal/service"
)
|
||
|
||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||
type S3BackupStore struct {
|
||
client *s3.Client
|
||
bucket string
|
||
}
|
||
|
||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||
region := cfg.Region
|
||
if region == "" {
|
||
region = "auto" // Cloudflare R2 默认 region
|
||
}
|
||
|
||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||
awsconfig.WithRegion(region),
|
||
awsconfig.WithCredentialsProvider(
|
||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||
),
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("load aws config: %w", err)
|
||
}
|
||
|
||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||
if cfg.Endpoint != "" {
|
||
o.BaseEndpoint = &cfg.Endpoint
|
||
}
|
||
if cfg.ForcePathStyle {
|
||
o.UsePathStyle = true
|
||
}
|
||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||
})
|
||
|
||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||
}
|
||
}
|
||
|
||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||
data, err := io.ReadAll(body)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("read body: %w", err)
|
||
}
|
||
|
||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||
Bucket: &s.bucket,
|
||
Key: &key,
|
||
Body: bytes.NewReader(data),
|
||
ContentType: &contentType,
|
||
})
|
||
if err != nil {
|
||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||
}
|
||
return int64(len(data)), nil
|
||
}
|
||
|
||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||
Bucket: &s.bucket,
|
||
Key: &key,
|
||
})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||
}
|
||
return result.Body, nil
|
||
}
|
||
|
||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||
Bucket: &s.bucket,
|
||
Key: &key,
|
||
})
|
||
return err
|
||
}
|
||
|
||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||
presignClient := s3.NewPresignClient(s.client)
|
||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||
Bucket: &s.bucket,
|
||
Key: &key,
|
||
}, s3.WithPresignExpires(expiry))
|
||
if err != nil {
|
||
return "", fmt.Errorf("presign url: %w", err)
|
||
}
|
||
return result.URL, nil
|
||
}
|
||
|
||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||
Bucket: &s.bucket,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||
}
|
||
return nil
|
||
}
|