2026-02-08 12:05:39 +08:00
|
|
|
|
//go:build unit
|
|
|
|
|
|
|
|
|
|
|
|
package middleware
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"encoding/json"
|
|
|
|
|
|
"errors"
|
|
|
|
|
|
"net/http"
|
|
|
|
|
|
"net/http/httptest"
|
|
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。
|
|
|
|
|
|
type stubJWTUserRepo struct {
|
|
|
|
|
|
service.UserRepository
|
|
|
|
|
|
users map[int64]*service.User
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) {
|
|
|
|
|
|
u, ok := r.users[id]
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return nil, errors.New("user not found")
|
|
|
|
|
|
}
|
|
|
|
|
|
return u, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
|
|
|
|
|
|
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
|
|
|
|
|
|
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
|
|
|
|
|
|
cfg := &config.Config{}
|
|
|
|
|
|
cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
|
|
|
|
|
|
cfg.JWT.AccessTokenExpireMinutes = 60
|
|
|
|
|
|
|
|
|
|
|
|
userRepo := &stubJWTUserRepo{users: users}
|
2026-03-09 13:13:39 +08:00
|
|
|
|
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
2026-02-08 12:05:39 +08:00
|
|
|
|
userSvc := service.NewUserService(userRepo, nil, nil)
|
|
|
|
|
|
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
|
|
|
|
|
|
|
|
|
|
|
r := gin.New()
|
|
|
|
|
|
r.Use(gin.HandlerFunc(mw))
|
|
|
|
|
|
r.GET("/protected", func(c *gin.Context) {
|
|
|
|
|
|
subject, _ := GetAuthSubjectFromContext(c)
|
|
|
|
|
|
role, _ := GetUserRoleFromContext(c)
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
|
|
|
|
"user_id": subject.UserID,
|
|
|
|
|
|
"role": role,
|
|
|
|
|
|
})
|
|
|
|
|
|
})
|
|
|
|
|
|
return r, authSvc
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestJWTAuth_ValidToken(t *testing.T) {
|
|
|
|
|
|
user := &service.User{
|
|
|
|
|
|
ID: 1,
|
|
|
|
|
|
Email: "test@example.com",
|
|
|
|
|
|
Role: "user",
|
|
|
|
|
|
Status: service.StatusActive,
|
|
|
|
|
|
Concurrency: 5,
|
|
|
|
|
|
TokenVersion: 1,
|
|
|
|
|
|
}
|
|
|
|
|
|
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
|
|
|
|
|
|
|
|
|
|
|
token, err := authSvc.GenerateToken(user)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusOK, w.Code)
|
|
|
|
|
|
|
|
|
|
|
|
var body map[string]any
|
|
|
|
|
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
|
|
|
|
|
require.Equal(t, float64(1), body["user_id"])
|
|
|
|
|
|
require.Equal(t, "user", body["role"])
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-10 00:37:47 +08:00
|
|
|
|
func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
|
|
|
|
|
|
user := &service.User{
|
|
|
|
|
|
ID: 1,
|
|
|
|
|
|
Email: "test@example.com",
|
|
|
|
|
|
Role: "user",
|
|
|
|
|
|
Status: service.StatusActive,
|
|
|
|
|
|
Concurrency: 5,
|
|
|
|
|
|
TokenVersion: 1,
|
|
|
|
|
|
}
|
|
|
|
|
|
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
|
|
|
|
|
|
|
|
|
|
|
token, err := authSvc.GenerateToken(user)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
req.Header.Set("Authorization", "bearer "+token)
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusOK, w.Code)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-08 12:05:39 +08:00
|
|
|
|
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
|
|
|
|
|
|
router, _ := newJWTTestEnv(nil)
|
|
|
|
|
|
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
|
|
|
|
|
var body ErrorResponse
|
|
|
|
|
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
|
|
|
|
|
require.Equal(t, "UNAUTHORIZED", body.Code)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestJWTAuth_InvalidHeaderFormat(t *testing.T) {
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|
name string
|
|
|
|
|
|
header string
|
|
|
|
|
|
}{
|
|
|
|
|
|
{"无Bearer前缀", "Token abc123"},
|
|
|
|
|
|
{"缺少空格分隔", "Bearerabc123"},
|
|
|
|
|
|
{"仅有单词", "abc123"},
|
|
|
|
|
|
}
|
|
|
|
|
|
router, _ := newJWTTestEnv(nil)
|
|
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
req.Header.Set("Authorization", tt.header)
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
|
|
|
|
|
var body ErrorResponse
|
|
|
|
|
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
|
|
|
|
|
require.Equal(t, "INVALID_AUTH_HEADER", body.Code)
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestJWTAuth_EmptyToken(t *testing.T) {
|
|
|
|
|
|
router, _ := newJWTTestEnv(nil)
|
|
|
|
|
|
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
req.Header.Set("Authorization", "Bearer ")
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
|
|
|
|
|
var body ErrorResponse
|
|
|
|
|
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
|
|
|
|
|
require.Equal(t, "EMPTY_TOKEN", body.Code)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestJWTAuth_TamperedToken(t *testing.T) {
|
|
|
|
|
|
router, _ := newJWTTestEnv(nil)
|
|
|
|
|
|
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature")
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
|
|
|
|
|
var body ErrorResponse
|
|
|
|
|
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
|
|
|
|
|
require.Equal(t, "INVALID_TOKEN", body.Code)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestJWTAuth_UserNotFound(t *testing.T) {
|
|
|
|
|
|
// 使用 user ID=1 的 token,但 repo 中没有该用户
|
|
|
|
|
|
fakeUser := &service.User{
|
|
|
|
|
|
ID: 999,
|
|
|
|
|
|
Email: "ghost@example.com",
|
|
|
|
|
|
Role: "user",
|
|
|
|
|
|
Status: service.StatusActive,
|
|
|
|
|
|
TokenVersion: 1,
|
|
|
|
|
|
}
|
|
|
|
|
|
// 创建环境时不注入此用户,这样 GetByID 会失败
|
|
|
|
|
|
router, authSvc := newJWTTestEnv(map[int64]*service.User{})
|
|
|
|
|
|
|
|
|
|
|
|
token, err := authSvc.GenerateToken(fakeUser)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
|
|
|
|
|
var body ErrorResponse
|
|
|
|
|
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
|
|
|
|
|
require.Equal(t, "USER_NOT_FOUND", body.Code)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestJWTAuth_UserInactive(t *testing.T) {
|
|
|
|
|
|
user := &service.User{
|
|
|
|
|
|
ID: 1,
|
|
|
|
|
|
Email: "disabled@example.com",
|
|
|
|
|
|
Role: "user",
|
|
|
|
|
|
Status: service.StatusDisabled,
|
|
|
|
|
|
TokenVersion: 1,
|
|
|
|
|
|
}
|
|
|
|
|
|
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
|
|
|
|
|
|
|
|
|
|
|
token, err := authSvc.GenerateToken(user)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
|
|
|
|
|
var body ErrorResponse
|
|
|
|
|
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
|
|
|
|
|
require.Equal(t, "USER_INACTIVE", body.Code)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func TestJWTAuth_TokenVersionMismatch(t *testing.T) {
|
|
|
|
|
|
// Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改)
|
|
|
|
|
|
userForToken := &service.User{
|
|
|
|
|
|
ID: 1,
|
|
|
|
|
|
Email: "test@example.com",
|
|
|
|
|
|
Role: "user",
|
|
|
|
|
|
Status: service.StatusActive,
|
|
|
|
|
|
TokenVersion: 1,
|
|
|
|
|
|
}
|
|
|
|
|
|
userInDB := &service.User{
|
|
|
|
|
|
ID: 1,
|
|
|
|
|
|
Email: "test@example.com",
|
|
|
|
|
|
Role: "user",
|
|
|
|
|
|
Status: service.StatusActive,
|
|
|
|
|
|
TokenVersion: 2, // 密码修改后版本递增
|
|
|
|
|
|
}
|
|
|
|
|
|
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: userInDB})
|
|
|
|
|
|
|
|
|
|
|
|
token, err := authSvc.GenerateToken(userForToken)
|
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
|
|
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
|
|
|
|
|
var body ErrorResponse
|
|
|
|
|
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
|
|
|
|
|
require.Equal(t, "TOKEN_REVOKED", body.Code)
|
|
|
|
|
|
}
|