mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-20 06:44:44 +08:00
feat(csp): auto-inject purchase_subscription_url origin into frame-src
This commit is contained in:
@@ -1 +1 @@
|
|||||||
0.1.87.16
|
0.1.87.17
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ func runSetupServer() {
|
|||||||
r := gin.New()
|
r := gin.New()
|
||||||
r.Use(middleware.Recovery())
|
r.Use(middleware.Recovery())
|
||||||
r.Use(middleware.CORS(config.CORSConfig{}))
|
r.Use(middleware.CORS(config.CORSConfig{}))
|
||||||
r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}))
|
r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}, nil))
|
||||||
|
|
||||||
// Register setup routes
|
// Register setup routes
|
||||||
setup.RegisterRoutes(r)
|
setup.RegisterRoutes(r)
|
||||||
|
|||||||
@@ -41,7 +41,9 @@ func GetNonceFromContext(c *gin.Context) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SecurityHeaders sets baseline security headers for all responses.
|
// SecurityHeaders sets baseline security headers for all responses.
|
||||||
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
// getFrameSrc is an optional function that returns an extra origin to inject into frame-src;
|
||||||
|
// pass nil to disable dynamic frame-src injection.
|
||||||
|
func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.HandlerFunc {
|
||||||
policy := strings.TrimSpace(cfg.Policy)
|
policy := strings.TrimSpace(cfg.Policy)
|
||||||
if policy == "" {
|
if policy == "" {
|
||||||
policy = config.DefaultCSPPolicy
|
policy = config.DefaultCSPPolicy
|
||||||
@@ -51,6 +53,13 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
|||||||
policy = enhanceCSPPolicy(policy)
|
policy = enhanceCSPPolicy(policy)
|
||||||
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
finalPolicy := policy
|
||||||
|
if getFrameSrc != nil {
|
||||||
|
if origin := getFrameSrc(); origin != "" {
|
||||||
|
finalPolicy = addToDirective(finalPolicy, "frame-src", origin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.Header("X-Content-Type-Options", "nosniff")
|
c.Header("X-Content-Type-Options", "nosniff")
|
||||||
c.Header("X-Frame-Options", "DENY")
|
c.Header("X-Frame-Options", "DENY")
|
||||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||||
@@ -61,12 +70,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
|
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
|
||||||
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
|
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
|
||||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'")
|
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'"))
|
||||||
c.Header("Content-Security-Policy", finalPolicy)
|
|
||||||
} else {
|
} else {
|
||||||
c.Set(CSPNonceKey, nonce)
|
c.Set(CSPNonceKey, nonce)
|
||||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
|
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'"))
|
||||||
c.Header("Content-Security-Policy", finalPolicy)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ func TestGetNonceFromContext(t *testing.T) {
|
|||||||
func TestSecurityHeaders(t *testing.T) {
|
func TestSecurityHeaders(t *testing.T) {
|
||||||
t.Run("sets_basic_security_headers", func(t *testing.T) {
|
t.Run("sets_basic_security_headers", func(t *testing.T) {
|
||||||
cfg := config.CSPConfig{Enabled: false}
|
cfg := config.CSPConfig{Enabled: false}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -99,7 +99,7 @@ func TestSecurityHeaders(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
|
t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
|
||||||
cfg := config.CSPConfig{Enabled: false}
|
cfg := config.CSPConfig{Enabled: false}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -115,7 +115,7 @@ func TestSecurityHeaders(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Policy: "default-src 'self'",
|
Policy: "default-src 'self'",
|
||||||
}
|
}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -136,7 +136,7 @@ func TestSecurityHeaders(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Policy: "script-src 'self' __CSP_NONCE__",
|
Policy: "script-src 'self' __CSP_NONCE__",
|
||||||
}
|
}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -160,7 +160,7 @@ func TestSecurityHeaders(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Policy: "",
|
Policy: "",
|
||||||
}
|
}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -179,7 +179,7 @@ func TestSecurityHeaders(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Policy: " \t\n ",
|
Policy: " \t\n ",
|
||||||
}
|
}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -197,7 +197,7 @@ func TestSecurityHeaders(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
|
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
|
||||||
}
|
}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -215,7 +215,7 @@ func TestSecurityHeaders(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("calls_next_handler", func(t *testing.T) {
|
t.Run("calls_next_handler", func(t *testing.T) {
|
||||||
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
|
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
nextCalled := false
|
nextCalled := false
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
@@ -238,7 +238,7 @@ func TestSecurityHeaders(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Policy: "script-src __CSP_NONCE__",
|
Policy: "script-src __CSP_NONCE__",
|
||||||
}
|
}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
nonces := make(map[string]bool)
|
nonces := make(map[string]bool)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
@@ -356,7 +356,7 @@ func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Policy: "script-src 'self' __CSP_NONCE__",
|
Policy: "script-src 'self' __CSP_NONCE__",
|
||||||
}
|
}
|
||||||
middleware := SecurityHeaders(cfg)
|
middleware := SecurityHeaders(cfg, nil)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
@@ -14,6 +18,19 @@ import (
|
|||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// extractOrigin returns the scheme+host origin from rawURL, or "" on error.
|
||||||
|
func extractOrigin(rawURL string) string {
|
||||||
|
rawURL = strings.TrimSpace(rawURL)
|
||||||
|
if rawURL == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
u, err := url.Parse(rawURL)
|
||||||
|
if err != nil || u.Host == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return u.Scheme + "://" + u.Host
|
||||||
|
}
|
||||||
|
|
||||||
// SetupRouter 配置路由器中间件和路由
|
// SetupRouter 配置路由器中间件和路由
|
||||||
func SetupRouter(
|
func SetupRouter(
|
||||||
r *gin.Engine,
|
r *gin.Engine,
|
||||||
@@ -28,11 +45,33 @@ func SetupRouter(
|
|||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
) *gin.Engine {
|
) *gin.Engine {
|
||||||
|
// 缓存 purchase_subscription_url 的 origin,用于动态注入 CSP frame-src
|
||||||
|
var cachedPaymentOrigin atomic.Pointer[string]
|
||||||
|
empty := ""
|
||||||
|
cachedPaymentOrigin.Store(&empty)
|
||||||
|
|
||||||
|
refreshPaymentOrigin := func() {
|
||||||
|
settings, err := settingService.GetPublicSettings(context.Background())
|
||||||
|
if err == nil && settings.PurchaseSubscriptionEnabled {
|
||||||
|
origin := extractOrigin(settings.PurchaseSubscriptionURL)
|
||||||
|
cachedPaymentOrigin.Store(&origin)
|
||||||
|
} else {
|
||||||
|
e := ""
|
||||||
|
cachedPaymentOrigin.Store(&e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
refreshPaymentOrigin() // 启动时初始化
|
||||||
|
|
||||||
// 应用中间件
|
// 应用中间件
|
||||||
r.Use(middleware2.RequestLogger())
|
r.Use(middleware2.RequestLogger())
|
||||||
r.Use(middleware2.Logger())
|
r.Use(middleware2.Logger())
|
||||||
r.Use(middleware2.CORS(cfg.CORS))
|
r.Use(middleware2.CORS(cfg.CORS))
|
||||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
|
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() string {
|
||||||
|
if p := cachedPaymentOrigin.Load(); p != nil {
|
||||||
|
return *p
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}))
|
||||||
|
|
||||||
// Serve embedded frontend with settings injection if available
|
// Serve embedded frontend with settings injection if available
|
||||||
if web.HasEmbeddedFrontend() {
|
if web.HasEmbeddedFrontend() {
|
||||||
@@ -40,11 +79,17 @@ func SetupRouter(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
|
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
|
||||||
r.Use(web.ServeEmbeddedFrontend())
|
r.Use(web.ServeEmbeddedFrontend())
|
||||||
|
settingService.SetOnUpdateCallback(refreshPaymentOrigin)
|
||||||
} else {
|
} else {
|
||||||
// Register cache invalidation callback
|
// Register combined callback: invalidate HTML cache + refresh payment origin
|
||||||
settingService.SetOnUpdateCallback(frontendServer.InvalidateCache)
|
settingService.SetOnUpdateCallback(func() {
|
||||||
|
frontendServer.InvalidateCache()
|
||||||
|
refreshPaymentOrigin()
|
||||||
|
})
|
||||||
r.Use(frontendServer.Middleware())
|
r.Use(frontendServer.Middleware())
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
settingService.SetOnUpdateCallback(refreshPaymentOrigin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册路由
|
// 注册路由
|
||||||
|
|||||||
Reference in New Issue
Block a user