From daa7c783b9ddca99c7168cd4847c121a80c912a6 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 2 Mar 2026 01:14:58 +0800 Subject: [PATCH] fix(csp): add timeout ctx, preserve cache on error, validate scheme in extractOrigin --- backend/internal/server/router.go | 17 ++++++++++-- backend/internal/server/router_test.go | 36 ++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 backend/internal/server/router_test.go diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 14335fe6..93b7b808 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -6,6 +6,7 @@ import ( "net/url" "strings" "sync/atomic" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -19,6 +20,7 @@ import ( ) // extractOrigin returns the scheme+host origin from rawURL, or "" on error. +// Only http and https schemes are accepted; other values (e.g. "//host/path") return "". func extractOrigin(rawURL string) string { rawURL = strings.TrimSpace(rawURL) if rawURL == "" { @@ -28,9 +30,14 @@ func extractOrigin(rawURL string) string { if err != nil || u.Host == "" { return "" } + if u.Scheme != "http" && u.Scheme != "https" { + return "" + } return u.Scheme + "://" + u.Host } +const paymentOriginFetchTimeout = 5 * time.Second + // SetupRouter 配置路由器中间件和路由 func SetupRouter( r *gin.Engine, @@ -51,8 +58,14 @@ func SetupRouter( cachedPaymentOrigin.Store(&empty) refreshPaymentOrigin := func() { - settings, err := settingService.GetPublicSettings(context.Background()) - if err == nil && settings.PurchaseSubscriptionEnabled { + ctx, cancel := context.WithTimeout(context.Background(), paymentOriginFetchTimeout) + defer cancel() + settings, err := settingService.GetPublicSettings(ctx) + if err != nil { + // 获取失败时保留已有缓存,避免 frame-src 被意外清空 + return + } + if settings.PurchaseSubscriptionEnabled { origin := extractOrigin(settings.PurchaseSubscriptionURL) cachedPaymentOrigin.Store(&origin) } else { diff --git a/backend/internal/server/router_test.go b/backend/internal/server/router_test.go new file mode 100644 index 00000000..7e679466 --- /dev/null +++ b/backend/internal/server/router_test.go @@ -0,0 +1,36 @@ +//go:build unit + +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractOrigin(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"empty string", "", ""}, + {"whitespace only", " ", ""}, + {"valid https", "https://pay.example.com/checkout", "https://pay.example.com"}, + {"valid http", "http://pay.example.com/checkout", "http://pay.example.com"}, + {"https with port", "https://pay.example.com:8443/checkout", "https://pay.example.com:8443"}, + {"protocol-relative //host", "//pay.example.com/path", ""}, + {"no scheme", "pay.example.com/path", ""}, + {"ftp scheme rejected", "ftp://pay.example.com/file", ""}, + {"empty host after parse", "https:///path", ""}, + {"invalid url", "://bad url", ""}, + {"only scheme", "https://", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractOrigin(tt.input) + assert.Equal(t, tt.want, got) + }) + } +}