fix(csp): add timeout ctx, preserve cache on error, validate scheme in extractOrigin

This commit is contained in:
erio
2026-03-02 01:14:58 +08:00
parent 8a82a2a648
commit daa7c783b9
2 changed files with 51 additions and 2 deletions

View File

@@ -6,6 +6,7 @@ import (
"net/url"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
@@ -19,6 +20,7 @@ import (
)
// extractOrigin returns the scheme+host origin from rawURL, or "" on error.
// Only http and https schemes are accepted; other values (e.g. "//host/path") return "".
func extractOrigin(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
@@ -28,9 +30,14 @@ func extractOrigin(rawURL string) string {
if err != nil || u.Host == "" {
return ""
}
if u.Scheme != "http" && u.Scheme != "https" {
return ""
}
return u.Scheme + "://" + u.Host
}
const paymentOriginFetchTimeout = 5 * time.Second
// SetupRouter 配置路由器中间件和路由
func SetupRouter(
r *gin.Engine,
@@ -51,8 +58,14 @@ func SetupRouter(
cachedPaymentOrigin.Store(&empty)
refreshPaymentOrigin := func() {
settings, err := settingService.GetPublicSettings(context.Background())
if err == nil && settings.PurchaseSubscriptionEnabled {
ctx, cancel := context.WithTimeout(context.Background(), paymentOriginFetchTimeout)
defer cancel()
settings, err := settingService.GetPublicSettings(ctx)
if err != nil {
// 获取失败时保留已有缓存,避免 frame-src 被意外清空
return
}
if settings.PurchaseSubscriptionEnabled {
origin := extractOrigin(settings.PurchaseSubscriptionURL)
cachedPaymentOrigin.Store(&origin)
} else {

View File

@@ -0,0 +1,36 @@
//go:build unit
package server
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtractOrigin(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"empty string", "", ""},
{"whitespace only", " ", ""},
{"valid https", "https://pay.example.com/checkout", "https://pay.example.com"},
{"valid http", "http://pay.example.com/checkout", "http://pay.example.com"},
{"https with port", "https://pay.example.com:8443/checkout", "https://pay.example.com:8443"},
{"protocol-relative //host", "//pay.example.com/path", ""},
{"no scheme", "pay.example.com/path", ""},
{"ftp scheme rejected", "ftp://pay.example.com/file", ""},
{"empty host after parse", "https:///path", ""},
{"invalid url", "://bad url", ""},
{"only scheme", "https://", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractOrigin(tt.input)
assert.Equal(t, tt.want, got)
})
}
}