fix(auth): harden pending oauth and backend mode flows

This commit is contained in:
IanShaw027
2026-04-22 12:30:00 +08:00
parent 1ffebbb568
commit 767f2f2dfe
12 changed files with 568 additions and 44 deletions

View File

@@ -678,6 +678,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// 不影响登出流程
}
}
h.consumePendingOAuthSessionOnLogout(c)
clearOAuthLogoutCookies(c)
response.Success(c, LogoutResponse{
Message: "Logged out successfully",

View File

@@ -469,6 +469,15 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -757,6 +757,61 @@ func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *test
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-choice-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-choice-subject-1").
SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)

View File

@@ -0,0 +1,68 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("logout-pending-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("logout-subject-123").
SetBrowserSessionKey("logout-browser-session-key").
SetResolvedEmail("logout@example.com").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
cookies := recorder.Result().Cookies()
for _, name := range []string{
oauthPendingSessionCookieName,
oauthPendingBrowserCookieName,
oauthBindAccessTokenCookieName,
linuxDoOAuthStateCookieName,
oidcOAuthStateCookieName,
wechatOAuthStateCookieName,
wechatPaymentOAuthStateName,
} {
cookie := findCookie(cookies, name)
require.NotNil(t, cookie, name)
require.Equal(t, -1, cookie.MaxAge, name)
require.True(t, cookie.HttpOnly, name)
}
storedSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
}

View File

@@ -310,6 +310,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes
return nil
}
func buildLegacyCompleteRegistrationPendingResponse(
session *dbent.PendingAuthSession,
forceEmailOnSignup bool,
emailVerificationRequired bool,
) map[string]any {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
"step": oauthPendingChoiceStep,
"adoption_required": true,
"create_account_allowed": true,
"force_email_on_signup": forceEmailOnSignup,
}))
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
if _, exists := completionResponse["email"]; !exists {
completionResponse["email"] = email
}
if _, exists := completionResponse["resolved_email"]; !exists {
completionResponse["resolved_email"] = email
}
}
if _, exists := completionResponse["choice_reason"]; !exists {
switch {
case forceEmailOnSignup:
completionResponse["choice_reason"] = "force_email_on_signup"
case emailVerificationRequired:
completionResponse["choice_reason"] = "email_verification_required"
default:
completionResponse["choice_reason"] = "third_party_signup"
}
}
return completionResponse
}
func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
c *gin.Context,
session *dbent.PendingAuthSession,
) (*dbent.PendingAuthSession, bool, error) {
if session == nil {
return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
if step := pendingSessionStringValue(payload, "step"); step != "" {
return session, true, nil
}
emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
if !emailVerificationRequired && !forceEmailOnSignup {
return session, false, nil
}
client := h.entClient()
if client == nil {
return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
updatedSession, err := updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
strings.TrimSpace(session.Intent),
strings.TrimSpace(session.ResolvedEmail),
nil,
buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
)
if err != nil {
return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
}
return updatedSession, true, nil
}
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
}
@@ -1272,6 +1344,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au
return svc, session, clearCookies, nil
}
func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
if c == nil || c.Request == nil {
return
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
return
}
svc, err := h.pendingIdentityService()
if err != nil {
return
}
_, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
}
func clearOAuthLogoutCookies(c *gin.Context) {
secureCookie := isRequestHTTPS(c)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
clearOAuthBindAccessTokenCookie(c, secureCookie)
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
}
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
payload := gin.H{
@@ -1451,6 +1576,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return

View File

@@ -1228,6 +1228,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
require.Nil(t, storedSession.ConsumedAt)
}
func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) {
handler, _ := newOAuthPendingFlowTestHandler(t, false)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge)
}
func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
ctx := context.Background()

View File

@@ -374,19 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderSubject: subject,
}
upstreamClaims := map[string]any{
"email": email,
"username": username,
"subject": subject,
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"email": email,
"username": username,
"subject": subject,
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
if idClaims != nil {
return idClaims.Name
}
return ""
}(), username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
"suggested_avatar_url": userInfoClaims.AvatarURL,
}
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
upstreamClaims["compat_email"] = compatEmail
@@ -622,6 +622,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -692,6 +692,62 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-choice-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-choice-subject-1").
SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid").
SetBrowserSessionKey("oidc-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
type oidcProviderFixture struct {
Subject string
PreferredUsername string

View File

@@ -525,6 +525,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -19,7 +19,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/repository"
@@ -700,7 +699,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
require.Zero(t, count)
}
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -773,27 +772,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require.Equal(t, http.StatusOK, completeRecorder.Code)
responseData := decodeJSONBody(t, completeRecorder)
require.NotEmpty(t, responseData["access_token"])
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["adoption_required"])
require.Empty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "WeChat Display", userEntity.Username)
require.Nil(t, consumed.ConsumedAt)
identity, err := client.AuthIdentity.Query().
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ("wechat-main"),
authidentity.ProviderSubjectEQ("union-456"),
).
Only(ctx)
Count(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
require.Zero(t, identityCount)
channel, err := client.AuthIdentityChannel.Query().
channelCount, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ProviderKeyEQ("wechat-main"),
@@ -801,25 +805,15 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
authidentitychannel.ChannelSubjectEQ("openid-123"),
).
Only(ctx)
Count(ctx)
require.NoError(t, err)
require.Equal(t, identity.ID, channel.IdentityID)
require.Equal(t, "union-456", channel.Metadata["unionid"])
require.Zero(t, channelCount)
decision, err := client.IdentityAdoptionDecision.Query().
decisionCount, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Only(ctx)
Count(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
require.Zero(t, decisionCount)
}
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
@@ -981,6 +975,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("wechat-complete-choice-session").
SetIntent("login").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("wechat-choice-subject-1").
SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
SetBrowserSessionKey("wechat-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "wechat_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")})
completeCtx.Request = req
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL