From 8a82a2a64890e65c8d50e9d9bad08f80b1831be7 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 2 Mar 2026 00:19:25 +0800 Subject: [PATCH] feat(csp): auto-inject purchase_subscription_url origin into frame-src --- backend/cmd/server/VERSION | 2 +- backend/cmd/server/main.go | 2 +- .../server/middleware/security_headers.go | 17 +++++-- .../middleware/security_headers_test.go | 20 ++++---- backend/internal/server/router.go | 51 +++++++++++++++++-- 5 files changed, 72 insertions(+), 20 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 4aaa184a..937b2cde 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.87.16 +0.1.87.17 diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 63095209..46edcb69 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -100,7 +100,7 @@ func runSetupServer() { r := gin.New() r.Use(middleware.Recovery()) r.Use(middleware.CORS(config.CORSConfig{})) - r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy})) + r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}, nil)) // Register setup routes setup.RegisterRoutes(r) diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 67b19c09..f947241e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -41,7 +41,9 @@ func GetNonceFromContext(c *gin.Context) string { } // SecurityHeaders sets baseline security headers for all responses. -func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { +// getFrameSrc is an optional function that returns an extra origin to inject into frame-src; +// pass nil to disable dynamic frame-src injection. +func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.HandlerFunc { policy := strings.TrimSpace(cfg.Policy) if policy == "" { policy = config.DefaultCSPPolicy @@ -51,6 +53,13 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { policy = enhanceCSPPolicy(policy) return func(c *gin.Context) { + finalPolicy := policy + if getFrameSrc != nil { + if origin := getFrameSrc(); origin != "" { + finalPolicy = addToDirective(finalPolicy, "frame-src", origin) + } + } + c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Frame-Options", "DENY") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") @@ -61,12 +70,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { if err != nil { // crypto/rand 失败时降级为无 nonce 的 CSP 策略 log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err) - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'") - c.Header("Content-Security-Policy", finalPolicy) + c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'")) } else { c.Set(CSPNonceKey, nonce) - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") - c.Header("Content-Security-Policy", finalPolicy) + c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'")) } } c.Next() diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index 43462b82..8fc81fba 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -84,7 +84,7 @@ func TestGetNonceFromContext(t *testing.T) { func TestSecurityHeaders(t *testing.T) { t.Run("sets_basic_security_headers", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -99,7 +99,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("csp_disabled_no_csp_header", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -115,7 +115,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "default-src 'self'", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -136,7 +136,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -160,7 +160,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -179,7 +179,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: " \t\n ", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -197,7 +197,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -215,7 +215,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("calls_next_handler", func(t *testing.T) { cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nextCalled := false router := gin.New() @@ -238,7 +238,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nonces := make(map[string]bool) for i := 0; i < 10; i++ { @@ -356,7 +356,7 @@ func BenchmarkSecurityHeadersMiddleware(b *testing.B) { Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index fb91bc0e..14335fe6 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -1,7 +1,11 @@ package server import ( + "context" "log" + "net/url" + "strings" + "sync/atomic" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -14,6 +18,19 @@ import ( "github.com/redis/go-redis/v9" ) +// extractOrigin returns the scheme+host origin from rawURL, or "" on error. +func extractOrigin(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + return u.Scheme + "://" + u.Host +} + // SetupRouter 配置路由器中间件和路由 func SetupRouter( r *gin.Engine, @@ -28,11 +45,33 @@ func SetupRouter( cfg *config.Config, redisClient *redis.Client, ) *gin.Engine { + // 缓存 purchase_subscription_url 的 origin,用于动态注入 CSP frame-src + var cachedPaymentOrigin atomic.Pointer[string] + empty := "" + cachedPaymentOrigin.Store(&empty) + + refreshPaymentOrigin := func() { + settings, err := settingService.GetPublicSettings(context.Background()) + if err == nil && settings.PurchaseSubscriptionEnabled { + origin := extractOrigin(settings.PurchaseSubscriptionURL) + cachedPaymentOrigin.Store(&origin) + } else { + e := "" + cachedPaymentOrigin.Store(&e) + } + } + refreshPaymentOrigin() // 启动时初始化 + // 应用中间件 r.Use(middleware2.RequestLogger()) r.Use(middleware2.Logger()) r.Use(middleware2.CORS(cfg.CORS)) - r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) + r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() string { + if p := cachedPaymentOrigin.Load(); p != nil { + return *p + } + return "" + })) // Serve embedded frontend with settings injection if available if web.HasEmbeddedFrontend() { @@ -40,11 +79,17 @@ func SetupRouter( if err != nil { log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err) r.Use(web.ServeEmbeddedFrontend()) + settingService.SetOnUpdateCallback(refreshPaymentOrigin) } else { - // Register cache invalidation callback - settingService.SetOnUpdateCallback(frontendServer.InvalidateCache) + // Register combined callback: invalidate HTML cache + refresh payment origin + settingService.SetOnUpdateCallback(func() { + frontendServer.InvalidateCache() + refreshPaymentOrigin() + }) r.Use(frontendServer.Middleware()) } + } else { + settingService.SetOnUpdateCallback(refreshPaymentOrigin) } // 注册路由