mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-19 06:14:45 +08:00
109 lines
2.9 KiB
Go
109 lines
2.9 KiB
Go
|
|
package repository
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/rand"
|
||
|
|
"encoding/hex"
|
||
|
|
"fmt"
|
||
|
|
"log"
|
||
|
|
"strings"
|
||
|
|
|
||
|
|
"github.com/Wei-Shaw/sub2api/ent"
|
||
|
|
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
|
|
)
|
||
|
|
|
||
|
|
// securitySecretKeyJWT is the SecuritySecret key under which the JWT signing
// secret is stored and looked up.
const securitySecretKeyJWT = "jwt_secret"
|
||
|
|
|
||
|
|
// readRandomBytes fills a buffer with random bytes; it defaults to
// crypto/rand.Read. It is a package-level variable rather than a direct call so
// the entropy source can be swapped out (presumably by tests to force
// deterministic output or errors — confirm).
var readRandomBytes = rand.Read
|
||
|
|
|
||
|
|
func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error {
|
||
|
|
if client == nil {
|
||
|
|
return fmt.Errorf("nil ent client")
|
||
|
|
}
|
||
|
|
if cfg == nil {
|
||
|
|
return fmt.Errorf("nil config")
|
||
|
|
}
|
||
|
|
|
||
|
|
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||
|
|
if cfg.JWT.Secret != "" {
|
||
|
|
if err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret); err != nil {
|
||
|
|
return fmt.Errorf("persist jwt secret: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
secret, created, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("ensure jwt secret: %w", err)
|
||
|
|
}
|
||
|
|
cfg.JWT.Secret = secret
|
||
|
|
|
||
|
|
if created {
|
||
|
|
log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.")
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) {
|
||
|
|
existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||
|
|
if err == nil {
|
||
|
|
value := strings.TrimSpace(existing.Value)
|
||
|
|
if len([]byte(value)) < 32 {
|
||
|
|
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||
|
|
}
|
||
|
|
return value, false, nil
|
||
|
|
}
|
||
|
|
if !ent.IsNotFound(err) {
|
||
|
|
return "", false, err
|
||
|
|
}
|
||
|
|
|
||
|
|
generated, err := generateHexSecret(byteLength)
|
||
|
|
if err != nil {
|
||
|
|
return "", false, err
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := client.SecuritySecret.Create().
|
||
|
|
SetKey(key).
|
||
|
|
SetValue(generated).
|
||
|
|
OnConflictColumns(securitysecret.FieldKey).
|
||
|
|
DoNothing().
|
||
|
|
Exec(ctx); err != nil {
|
||
|
|
return "", false, err
|
||
|
|
}
|
||
|
|
|
||
|
|
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||
|
|
if err != nil {
|
||
|
|
return "", false, err
|
||
|
|
}
|
||
|
|
value := strings.TrimSpace(stored.Value)
|
||
|
|
if len([]byte(value)) < 32 {
|
||
|
|
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||
|
|
}
|
||
|
|
return value, value == generated, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) error {
|
||
|
|
value = strings.TrimSpace(value)
|
||
|
|
if len([]byte(value)) < 32 {
|
||
|
|
return fmt.Errorf("secret %q must be at least 32 bytes", key)
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err := client.SecuritySecret.Create().SetKey(key).SetValue(value).Save(ctx)
|
||
|
|
if err == nil || ent.IsConstraintError(err) {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func generateHexSecret(byteLength int) (string, error) {
|
||
|
|
if byteLength <= 0 {
|
||
|
|
byteLength = 32
|
||
|
|
}
|
||
|
|
buf := make([]byte, byteLength)
|
||
|
|
if _, err := readRandomBytes(buf); err != nil {
|
||
|
|
return "", fmt.Errorf("generate random secret: %w", err)
|
||
|
|
}
|
||
|
|
return hex.EncodeToString(buf), nil
|
||
|
|
}
|