Files

309 lines
9.0 KiB
Go
Raw Permalink Normal View History

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// init switches gin into test mode before any test in this file runs.
func init() {
	// cors_test and security_headers_test share this package; setting the
	// same mode again is harmless because the call is idempotent.
	const mode = gin.TestMode
	gin.SetMode(mode)
}
// --- Task 8.2: verify that CORS headers are set conditionally ---
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
origin string
}{
{
name: "preflight_disallowed_origin",
method: http.MethodOptions,
origin: "https://evil.example.com",
},
{
name: "get_disallowed_origin",
method: http.MethodGet,
origin: "https://evil.example.com",
},
{
name: "post_disallowed_origin",
method: http.MethodPost,
origin: "https://attacker.example.com",
},
{
name: "preflight_no_origin",
method: http.MethodOptions,
origin: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
if tt.origin != "" {
c.Request.Header.Set("Origin", tt.origin)
}
middleware(c)
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
"不允许的 origin 不应收到 Allow-Headers")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
"不允许的 origin 不应收到 Allow-Methods")
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
"不允许的 origin 不应收到 Max-Age")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
"不允许的 origin 不应收到 Allow-Origin")
})
}
}
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
}{
{name: "preflight_OPTIONS", method: http.MethodOptions},
{name: "normal_GET", method: http.MethodGet},
{name: "normal_POST", method: http.MethodPost},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"允许的 origin 应收到 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"允许的 origin 应收到 Allow-Methods")
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
"允许的 origin 应收到 Max-Age=86400")
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
"允许的 origin 应收到 Allow-Origin")
})
}
}
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Equal(t, http.StatusForbidden, w.Code,
"不允许的 origin 的 preflight 请求应返回 403")
}
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, http.StatusNoContent, w.Code,
"允许的 origin 的 preflight 请求应返回 204")
}
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any-origin.example.com")
middleware(c)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
"通配符配置应返回 *")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"通配符 origin 应设置 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"通配符 origin 应设置 Allow-Methods")
}
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: true,
}
middleware := CORS(cfg)
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
})
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"不允许的 origin 不应收到 Allow-Credentials")
})
}
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any.example.com")
middleware(c)
// 通配符 + credentials 不兼容credentials 应被禁用
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"通配符 origin 应禁用 Allow-Credentials")
}
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{
"https://app1.example.com",
"https://app2.example.com",
},
AllowCredentials: false,
}
middleware := CORS(cfg)
t.Run("first_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app1.example.com")
middleware(c)
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("second_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app2.example.com")
middleware(c)
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("unlisted_origin_rejected", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app3.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
}
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Contains(t, w.Header().Values("Vary"), "Origin",
"非通配符允许的 origin 应设置 Vary: Origin")
}
// TestNormalizeOrigins is a table-driven check of normalizeOrigins: nil and
// empty inputs are expected to yield nil, surrounding whitespace is expected
// to be trimmed, and empty or whitespace-only entries are expected to be
// dropped.
func TestNormalizeOrigins(t *testing.T) {
	cases := []struct {
		name string
		in   []string
		want []string
	}{
		{
			name: "nil_input",
			in:   nil,
			want: nil,
		},
		{
			name: "empty_input",
			in:   []string{},
			want: nil,
		},
		{
			name: "trims_whitespace",
			in:   []string{" https://a.com ", " https://b.com"},
			want: []string{"https://a.com", "https://b.com"},
		},
		{
			name: "removes_empty_strings",
			in:   []string{"", " ", "https://a.com"},
			want: []string{"https://a.com"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := normalizeOrigins(tc.in)
			assert.Equal(t, tc.want, got)
		})
	}
}