feat: complete email binding and pending oauth verification flows

This commit is contained in:
IanShaw027
2026-04-21 10:00:06 +08:00
parent 6da08262d7
commit dcd5c43da4
29 changed files with 2117 additions and 107 deletions

View File

@@ -79,7 +79,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient) totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService, emailService, emailCache) userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db) usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -27,7 +28,7 @@ import (
const ( const (
oauthPendingBrowserCookiePath = "/api/v1/auth/oauth" oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
oauthPendingBrowserCookieName = "oauth_pending_browser_session" oauthPendingBrowserCookieName = "oauth_pending_browser_session"
oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending" oauthPendingSessionCookiePath = "/api/v1/auth/oauth"
oauthPendingSessionCookieName = "oauth_pending_session" oauthPendingSessionCookieName = "oauth_pending_session"
oauthPendingCookieMaxAgeSec = 10 * 60 oauthPendingCookieMaxAgeSec = 10 * 60
@@ -66,6 +67,13 @@ type createPendingOAuthAccountRequest struct {
AdoptAvatar *bool `json:"adopt_avatar,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
} }
type sendPendingOAuthVerifyCodeRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token,omitempty"`
PendingAuthToken string `json:"pending_auth_token,omitempty"`
PendingOAuthToken string `json:"pending_oauth_token,omitempty"`
}
func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest { func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
return oauthAdoptionDecisionRequest{ return oauthAdoptionDecisionRequest{
AdoptDisplayName: r.AdoptDisplayName, AdoptDisplayName: r.AdoptDisplayName,
@@ -448,6 +456,43 @@ func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "") h.createPendingOAuthAccount(c, "")
} }
// SendPendingOAuthVerifyCode sends a verification code for a browser-bound
// pending OAuth account-creation flow.
// POST /api/v1/auth/oauth/pending/send-verify-code
func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
var req sendPendingOAuthVerifyCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
_, session, _, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, SendVerifyCodeResponse{
Message: "Verification code sent successfully",
Countdown: result.Countdown,
})
}
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context, c *gin.Context,
sessionID int64, sessionID int64,
@@ -1084,6 +1129,41 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
return payload return payload
} }
func (h *AuthHandler) transitionPendingOAuthAccountToBindLogin(
c *gin.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
email string,
decision oauthAdoptionDecisionRequest,
) (*dbent.PendingAuthSession, error) {
existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
if err != nil {
return nil, err
}
completionResponse := mergePendingCompletionResponse(session, map[string]any{
"step": "bind_login_required",
"email": email,
})
session, err = updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
"adopt_existing_user_by_email",
email,
&existingUser.ID,
completionResponse,
)
if err != nil {
return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
}
if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decision); err != nil {
return nil, err
}
return session, nil
}
func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) { func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken, "access_token": tokenPair.AccessToken,
@@ -1199,29 +1279,11 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return return
} }
if existingUser != nil { if existingUser != nil {
completionResponse := mergePendingCompletionResponse(session, map[string]any{ session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, req.adoptionDecision())
"step": "bind_login_required",
"email": email,
})
session, err = updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
"adopt_existing_user_by_email",
email,
&existingUser.ID,
completionResponse,
)
if err != nil { if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err))
return
}
if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return return
} }
@@ -1239,27 +1301,77 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
strings.TrimSpace(session.ProviderType), strings.TrimSpace(session.ProviderType),
) )
if err != nil { if err != nil {
response.ErrorFrom(c, err) if errors.Is(err, service.ErrEmailExists) {
return session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, req.adoptionDecision())
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
}
response.ErrorFrom(c, err)
return
}
rollbackCreatedUser := func(originalErr error) bool {
if user == nil || user.ID <= 0 {
return false
}
if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation(
c.Request.Context(),
user.ID,
strings.TrimSpace(req.InvitationCode),
); rollbackErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer(
"PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED",
"failed to rollback pending oauth account creation",
).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr)))
return true
}
user = nil
return false
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil { if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if err := h.authService.FinalizeOAuthEmailAccount(
c.Request.Context(),
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
); err != nil {
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, err)
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
if rollbackCreatedUser(err) {
return
}
clearCookies() clearCookies()
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearCookies() clearCookies()
writeOAuthTokenPairResponse(c, tokenPair) writeOAuthTokenPairResponse(c, tokenPair)
} }

View File

@@ -5,6 +5,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@@ -15,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -61,6 +63,18 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
require.Equal(t, true, payload["adoption_required"]) require.Equal(t, true, payload["adoption_required"])
} }
func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) {
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false)
cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, cookie)
require.Equal(t, "/api/v1/auth/oauth", cookie.Path)
}
func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) { func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false) handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background() ctx := context.Background()
@@ -943,6 +957,81 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
require.Nil(t, storedSession.ConsumedAt) require.Nil(t, storedSession.ConsumedAt)
} }
func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
ctx := context.Background()
conflictOwner, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(conflictOwner.ID).
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-conflict-123").
SetMetadata(map[string]any{
"username": "owner-user",
}).
Save(ctx)
require.NoError(t, err)
invitation, err := client.RedeemCode.Create().
SetCode("INVITE123").
SetType(service.RedeemTypeInvitation).
SetStatus(service.StatusUnused).
SetValue(0).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("create-account-conflict-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-conflict-123").
SetBrowserSessionKey("create-account-conflict-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
}).
SetRedirectTo("/profile").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")})
ginCtx.Request = req
handler.CreateOIDCOAuthAccount(ginCtx)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID)
require.NoError(t, err)
require.Equal(t, service.StatusUnused, storedInvitation.Status)
require.Nil(t, storedInvitation.UsedBy)
require.Nil(t, storedInvitation.UsedAt)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false) handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background() ctx := context.Background()
@@ -1529,6 +1618,8 @@ type oauthPendingFlowTestHandlerOptions struct {
defaultSubAssigner service.DefaultSubscriptionAssigner defaultSubAssigner service.DefaultSubscriptionAssigner
totpCache service.TotpCache totpCache service.TotpCache
totpEncryptor service.SecretEncryptor totpEncryptor service.SecretEncryptor
redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository
userRepoOptions oauthPendingFlowUserRepoOptions
} }
func newOAuthPendingFlowTestHandlerWithDependencies( func newOAuthPendingFlowTestHandlerWithDependencies(
@@ -1590,7 +1681,17 @@ CREATE TABLE IF NOT EXISTS user_avatars (
settingValues[key] = value settingValues[key] = value
} }
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg) settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
userRepo := &oauthPendingFlowUserRepo{client: client} userRepo := &oauthPendingFlowUserRepo{
client: client,
options: options.userRepoOptions,
}
redeemRepo := service.RedeemCodeRepository(nil)
if options.redeemRepoFactory != nil {
redeemRepo = options.redeemRepoFactory(client)
}
if redeemRepo == nil {
redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client}
}
var emailService *service.EmailService var emailService *service.EmailService
if options.emailCache != nil { if options.emailCache != nil {
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
@@ -1602,7 +1703,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
authSvc := service.NewAuthService( authSvc := service.NewAuthService(
client, client,
userRepo, userRepo,
nil, redeemRepo,
&oauthPendingFlowRefreshTokenCacheStub{}, &oauthPendingFlowRefreshTokenCacheStub{},
cfg, cfg,
settingSvc, settingSvc,
@@ -1797,6 +1898,127 @@ func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context,
return false, nil return false, nil
} }
type oauthPendingFlowRedeemCodeRepo struct {
client *dbent.Client
}
func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error {
panic("unexpected Create call")
}
func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error {
panic("unexpected CreateBatch call")
}
func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) {
panic("unexpected GetByID call")
}
func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrRedeemCodeNotFound
}
return nil, err
}
notes := ""
if entity.Notes != nil {
notes = *entity.Notes
}
return &service.RedeemCode{
ID: entity.ID,
Code: entity.Code,
Type: entity.Type,
Value: entity.Value,
Status: entity.Status,
UsedBy: entity.UsedBy,
UsedAt: entity.UsedAt,
Notes: notes,
CreatedAt: entity.CreatedAt,
GroupID: entity.GroupID,
ValidityDays: entity.ValidityDays,
}, nil
}
func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
if code == nil {
return nil
}
update := r.client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetType(code.Type).
SetValue(code.Value).
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays)
if code.UsedBy != nil {
update = update.SetUsedBy(*code.UsedBy)
} else {
update = update.ClearUsedBy()
}
if code.UsedAt != nil {
update = update.SetUsedAt(*code.UsedAt)
} else {
update = update.ClearUsedAt()
}
if code.GroupID != nil {
update = update.SetGroupID(*code.GroupID)
} else {
update = update.ClearGroupID()
}
_, err := update.Save(ctx)
return err
}
func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
affected, err := r.client.RedeemCode.Update().
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
SetStatus(service.StatusUsed).
SetUsedBy(userID).
SetUsedAt(time.Now().UTC()).
Save(ctx)
if err != nil {
return err
}
if affected == 0 {
return service.ErrRedeemCodeUsed
}
return nil
}
func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) {
panic("unexpected ListByUser call")
}
func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListByUserPaginated call")
}
func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
panic("unexpected SumPositiveBalanceByUser call")
}
type oauthPendingFlowFailingUseRedeemRepo struct {
*oauthPendingFlowRedeemCodeRepo
}
func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error {
return errors.New("forced invitation use failure")
}
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
t.Helper() t.Helper()
@@ -1872,6 +2094,11 @@ func countProviderGrantRecords(
type oauthPendingFlowUserRepo struct { type oauthPendingFlowUserRepo struct {
client *dbent.Client client *dbent.Client
options oauthPendingFlowUserRepoOptions
}
type oauthPendingFlowUserRepoOptions struct {
rejectDeleteWhileAuthIdentityExists bool
} }
func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error { func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
@@ -1953,6 +2180,15 @@ func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.Use
} }
func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
if r.options.rejectDeleteWhileAuthIdentityExists {
count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx)
if err != nil {
return err
}
if count > 0 {
return errors.New("cannot delete user while auth identities still exist")
}
}
return r.client.User.DeleteOneID(id).Exec(ctx) return r.client.User.DeleteOneID(id).Exec(ctx)
} }

View File

@@ -15,14 +15,21 @@ import (
// UserHandler handles user-related requests // UserHandler handles user-related requests
type UserHandler struct { type UserHandler struct {
userService *service.UserService userService *service.UserService
authService *service.AuthService
emailService *service.EmailService emailService *service.EmailService
emailCache service.EmailCache emailCache service.EmailCache
} }
// NewUserHandler creates a new UserHandler // NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler { func NewUserHandler(
userService *service.UserService,
authService *service.AuthService,
emailService *service.EmailService,
emailCache service.EmailCache,
) *UserHandler {
return &UserHandler{ return &UserHandler{
userService: userService, userService: userService,
authService: authService,
emailService: emailService, emailService: emailService,
emailCache: emailCache, emailCache: emailCache,
} }
@@ -157,6 +164,16 @@ type StartIdentityBindingRequest struct {
RedirectTo string `json:"redirect_to"` RedirectTo string `json:"redirect_to"`
} }
type BindEmailIdentityRequest struct {
Email string `json:"email" binding:"required,email"`
VerifyCode string `json:"verify_code" binding:"required"`
Password string `json:"password" binding:"required,min=6"`
}
type SendEmailBindingCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow. // StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow.
// POST /api/v1/user/auth-identities/bind/start // POST /api/v1/user/auth-identities/bind/start
func (h *UserHandler) StartIdentityBinding(c *gin.Context) { func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
@@ -183,6 +200,73 @@ func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
response.Success(c, result) response.Success(c, result)
} }
// BindEmailIdentity verifies and binds a local email identity for the current user.
// POST /api/v1/user/account-bindings/email
func (h *UserHandler) BindEmailIdentity(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if h.authService == nil {
response.InternalError(c, "Auth service not configured")
return
}
var req BindEmailIdentityRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
updatedUser, err := h.authService.BindEmailIdentity(
c.Request.Context(),
subject.UserID,
req.Email,
req.VerifyCode,
req.Password,
)
if err != nil {
response.ErrorFrom(c, err)
return
}
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profileResp)
}
// SendEmailBindingCode sends a verification code for the current user's email binding flow.
// POST /api/v1/user/account-bindings/email/send-code
func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if h.authService == nil {
response.InternalError(c, "Auth service not configured")
return
}
var req SendEmailBindingCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Verification code sent successfully"})
}
// SendNotifyEmailCodeRequest represents the request to send notify email verification code // SendNotifyEmailCodeRequest represents the request to send notify email verification code
type SendNotifyEmailCodeRequest struct { type SendNotifyEmailCodeRequest struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`

View File

@@ -11,6 +11,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -122,7 +123,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
Status: service.StatusActive, Status: service.StatusActive,
}, },
} }
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -180,7 +181,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
}, },
}, },
} }
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder) c, _ := gin.CreateTestContext(recorder)
@@ -262,7 +263,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
}, },
}, },
} }
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder) c, _ := gin.CreateTestContext(recorder)
@@ -311,6 +312,116 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.Equal(t, "linuxdo", usernameSource["source"]) require.Equal(t, "linuxdo", usernameSource["source"])
} }
type userHandlerEmailCacheStub struct {
data *service.VerificationCodeData
}
func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
return s.data, nil
}
func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error {
return nil
}
func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
return nil, nil
}
func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
return nil
}
func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
return nil, nil
}
func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
return nil
}
func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
return nil
}
func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
return false
}
func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
return nil
}
func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}
func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 11,
Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain,
Username: "legacy-user",
Role: service.RoleUser,
Status: service.StatusActive,
},
}
emailCache := &userHandlerEmailCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
c.Params = gin.Params{{Key: "provider", Value: "email"}}
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
handler.BindEmailIdentity(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Email string `json:"email"`
EmailBound bool `json:"email_bound"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, "new@example.com", resp.Data.Email)
require.True(t, resp.Data.EmailBound)
}
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@@ -323,7 +434,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
Status: service.StatusActive, Status: service.StatusActive,
}, },
} }
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View File

@@ -74,6 +74,12 @@ func RegisterAuthRoutes(
}), }),
h.Auth.ExchangePendingOAuthCompletion, h.Auth.ExchangePendingOAuthCompletion,
) )
auth.POST("/oauth/pending/send-verify-code",
rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.SendPendingOAuthVerifyCode,
)
auth.POST("/oauth/pending/create-account", auth.POST("/oauth/pending/create-account",
rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{ rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,

View File

@@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
"/api/v1/auth/login", "/api/v1/auth/login",
"/api/v1/auth/login/2fa", "/api/v1/auth/login/2fa",
"/api/v1/auth/send-verify-code", "/api/v1/auth/send-verify-code",
"/api/v1/auth/oauth/pending/send-verify-code",
} }
for _, path := range paths { for _, path := range paths {

View File

@@ -25,6 +25,8 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile) user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword) user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile) user.PUT("", h.User.UpdateProfile)
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
// 通知邮箱管理 // 通知邮箱管理

View File

@@ -0,0 +1,128 @@
package service
import (
"context"
"errors"
"fmt"
"net/mail"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// BindEmailIdentity verifies and binds a local email/password identity to the current user.
func (s *AuthService) BindEmailIdentity(
ctx context.Context,
userID int64,
email string,
verifyCode string,
password string,
) (*User, error) {
if s == nil {
return nil, ErrServiceUnavailable
}
normalizedEmail, err := normalizeEmailForIdentityBinding(email)
if err != nil {
return nil, err
}
if isReservedEmail(normalizedEmail) {
return nil, ErrEmailReserved
}
if strings.TrimSpace(password) == "" {
return nil, ErrPasswordRequired
}
if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil {
return nil, err
}
currentUser, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
switch {
case err == nil && existingUser != nil && existingUser.ID != userID:
return nil, ErrEmailExists
case err != nil && !errors.Is(err, ErrUserNotFound):
return nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(password)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
currentUser.Email = normalizedEmail
currentUser.PasswordHash = hashedPassword
if err := s.userRepo.Update(ctx, currentUser); err != nil {
if errors.Is(err, ErrEmailExists) {
return nil, ErrEmailExists
}
return nil, ErrServiceUnavailable
}
if firstRealEmailBind {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil {
return nil, fmt.Errorf("apply email first bind defaults: %w", err)
}
}
return currentUser, nil
}
// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
if s == nil {
return ErrServiceUnavailable
}
normalizedEmail, err := normalizeEmailForIdentityBinding(email)
if err != nil {
return err
}
if isReservedEmail(normalizedEmail) {
return ErrEmailReserved
}
if s.emailService == nil {
return ErrServiceUnavailable
}
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
if errors.Is(err, ErrUserNotFound) {
return ErrUserNotFound
}
return ErrServiceUnavailable
}
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
switch {
case err == nil && existingUser != nil && existingUser.ID != userID:
return ErrEmailExists
case err != nil && !errors.Is(err, ErrUserNotFound):
return ErrServiceUnavailable
}
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
}
func normalizeEmailForIdentityBinding(email string) (string, error) {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" || len(normalized) > 255 {
return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(normalized); err != nil {
return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
return normalized, nil
}
func hasBindableEmailIdentitySubject(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return normalized != "" && !isReservedEmail(normalized)
}

View File

@@ -4,9 +4,71 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net/mail"
"strings" "strings"
"time"
) )
func normalizeOAuthSignupSource(signupSource string) string {
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
return "email"
}
return signupSource
}
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
// account-creation flows without relying on the public registration gate.
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
email = strings.TrimSpace(strings.ToLower(email))
if email == "" {
return nil, ErrEmailVerifyRequired
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, ErrEmailVerifyRequired
}
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
if s == nil || s.emailService == nil {
return nil, ErrServiceUnavailable
}
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
return nil, err
}
return &SendVerifyCodeResult{
Countdown: int(verifyCodeCooldown / time.Second),
}, nil
}
func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil, nil
}
if s.redeemRepo == nil {
return nil, ErrServiceUnavailable
}
invitationCode = strings.TrimSpace(invitationCode)
if invitationCode == "" {
return nil, ErrInvitationCodeRequired
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
return nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
return nil, ErrInvitationCodeInvalid
}
return redeemCode, nil
}
// VerifyOAuthEmailCode verifies the locally entered email verification code for // VerifyOAuthEmailCode verifies the locally entered email verification code for
// third-party signup and binding flows. This is intentionally independent from // third-party signup and binding flows. This is intentionally independent from
// the global registration email verification toggle. // the global registration email verification toggle.
@@ -54,19 +116,8 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, err return nil, nil, err
} }
var invitationRedeemCode *RedeemCode if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
if s.settingService.IsInvitationCodeEnabled(ctx) { return nil, nil, err
if invitationCode == "" {
return nil, nil, ErrInvitationCodeRequired
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
return nil, nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
} }
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -104,22 +155,91 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, ErrServiceUnavailable return nil, nil, ErrServiceUnavailable
} }
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "") tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil { if err != nil {
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
return nil, nil, fmt.Errorf("generate token pair: %w", err) return nil, nil, fmt.Errorf("generate token pair: %w", err)
} }
return tokenPair, user, nil return tokenPair, user, nil
} }
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
// only after the pending OAuth flow has fully reached its last reversible step.
func (s *AuthService) FinalizeOAuthEmailAccount(
ctx context.Context,
user *User,
invitationCode string,
signupSource string,
) error {
if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable
}
signupSource = normalizeOAuthSignupSource(signupSource)
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
return err
}
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return ErrInvitationCodeInvalid
}
}
s.postAuthUserBootstrap(ctx, user, signupSource, false)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
return nil
}
// RollbackOAuthEmailAccountCreation removes a partially-created local account
// and restores any invitation code already consumed by that account.
func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error {
if s == nil || s.userRepo == nil || userID <= 0 {
return ErrServiceUnavailable
}
if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil {
return err
}
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete created oauth user: %w", err)
}
return nil
}
func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error {
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil
}
if s.redeemRepo == nil {
return ErrServiceUnavailable
}
invitationCode = strings.TrimSpace(invitationCode)
if invitationCode == "" || userID <= 0 {
return nil
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
if errors.Is(err, ErrRedeemCodeNotFound) {
return nil
}
return fmt.Errorf("load invitation code: %w", err)
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID {
return nil
}
redeemCode.Status = StatusUnused
redeemCode.UsedBy = nil
redeemCode.UsedAt = nil
if err := s.redeemRepo.Update(ctx, redeemCode); err != nil {
return fmt.Errorf("restore invitation code: %w", err)
}
return nil
}
// ValidatePasswordCredentials checks the local password without completing the // ValidatePasswordCredentials checks the local password without completing the
// login flow. This is used by pending third-party account adoption flows before // login flow. This is used by pending third-party account adoption flows before
// the external identity has been bound. // the external identity has been bound.

View File

@@ -0,0 +1,251 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type redeemCodeRepoStub struct {
codesByCode map[string]*RedeemCode
useCalls []struct {
id int64
userID int64
}
updateCalls []*RedeemCode
}
func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
panic("unexpected Create call")
}
func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
panic("unexpected CreateBatch call")
}
func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
panic("unexpected GetByID call")
}
func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
if s.codesByCode == nil {
return nil, ErrRedeemCodeNotFound
}
redeemCode, ok := s.codesByCode[code]
if !ok {
return nil, ErrRedeemCodeNotFound
}
cloned := *redeemCode
return &cloned, nil
}
func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
if code == nil {
return nil
}
cloned := *code
s.updateCalls = append(s.updateCalls, &cloned)
if s.codesByCode == nil {
s.codesByCode = make(map[string]*RedeemCode)
}
s.codesByCode[cloned.Code] = &cloned
return nil
}
func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
for code, redeemCode := range s.codesByCode {
if redeemCode.ID != id {
continue
}
now := time.Now().UTC()
redeemCode.Status = StatusUsed
redeemCode.UsedBy = &userID
redeemCode.UsedAt = &now
s.codesByCode[code] = redeemCode
s.useCalls = append(s.useCalls, struct {
id int64
userID int64
}{id: id, userID: userID})
return nil
}
return ErrRedeemCodeNotFound
}
func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
panic("unexpected ListByUser call")
}
func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListByUserPaginated call")
}
func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
panic("unexpected SumPositiveBalanceByUser call")
}
func newOAuthEmailFlowAuthService(
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
settings map[string]string,
emailCache EmailCache,
) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
return NewAuthService(
nil,
userRepo,
redeemRepo,
refreshTokenCache,
cfg,
settingService,
emailService,
nil,
nil,
nil,
nil,
)
}
func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
userRepo := &userRepoStub{nextID: 42}
redeemRepo := &redeemCodeRepoStub{
codesByCode: map[string]*RedeemCode{
"INVITE123": {
ID: 7,
Code: "INVITE123",
Type: RedeemTypeInvitation,
Status: StatusUnused,
},
},
}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
redeemRepo,
nil,
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyInvitationCodeEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fresh@example.com",
"secret-123",
"246810",
"INVITE123",
"oidc",
)
require.Nil(t, tokenPair)
require.Nil(t, user)
require.Error(t, err)
require.Contains(t, err.Error(), "generate token pair")
require.Equal(t, []int64{42}, userRepo.deletedIDs)
require.Len(t, userRepo.created, 1)
require.Empty(t, redeemRepo.useCalls)
require.Empty(t, redeemRepo.updateCalls)
}
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
userRepo := &userRepoStub{}
redeemRepo := &redeemCodeRepoStub{
codesByCode: map[string]*RedeemCode{
"INVITE123": {
ID: 7,
Code: "INVITE123",
Type: RedeemTypeInvitation,
Status: StatusUsed,
UsedBy: func() *int64 {
v := int64(42)
return &v
}(),
UsedAt: func() *time.Time {
v := time.Now().UTC()
return &v
}(),
},
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
redeemRepo,
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyInvitationCodeEnabled: "true",
},
&emailCacheStub{},
)
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
require.NoError(t, err)
require.Equal(t, []int64{42}, userRepo.deletedIDs)
require.Len(t, redeemRepo.updateCalls, 1)
require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
}
func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
},
&emailCacheStub{},
)
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
require.Error(t, err)
require.Contains(t, err.Error(), "delete created oauth user")
}

View File

@@ -0,0 +1,316 @@
//go:build unit
package service_test
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type emailBindDefaultSubAssignerStub struct {
calls []*service.AssignSubscriptionInput
}
func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
_ context.Context,
input *service.AssignSubscriptionInput,
) (*service.UserSubscription, bool, error) {
cloned := *input
s.calls = append(s.calls, &cloned)
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
func newAuthServiceForEmailBind(
t *testing.T,
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
provider_type TEXT NOT NULL,
grant_reason TEXT NOT NULL DEFAULT 'first_bind',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, provider_type, grant_reason)
)`)
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
repo := repository.NewUserRepository(client, db)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-bind-email-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingRepo := &emailBindSettingRepoStub{values: settings}
settingSvc := service.NewSettingService(settingRepo, cfg)
var emailSvc *service.EmailService
if emailCache != nil {
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) {
assigner := &emailBindDefaultSubAssignerStub{}
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, cache, assigner)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain).
SetUsername("legacy-user").
SetPasswordHash("old-hash").
SetBalance(2.5).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password")
require.NoError(t, err)
require.NotNil(t, updatedUser)
require.Equal(t, "newemail@example.com", updatedUser.Email)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "newemail@example.com", storedUser.Email)
require.Equal(t, 11.0, storedUser.Balance)
require.Equal(t, 5, storedUser.Concurrency)
require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("newemail@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
require.Len(t, assigner.calls, 1)
require.Equal(t, user.ID, assigner.calls[0].UserID)
require.Equal(t, int64(11), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
ctx := context.Background()
sourceUser, err := client.User.Create().
SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain).
SetUsername("source-user").
SetPasswordHash("old-hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.User.Create().
SetEmail("taken@example.com").
SetUsername("taken-user").
SetPasswordHash("hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password")
require.ErrorIs(t, err, service.ErrEmailExists)
require.Nil(t, updatedUser)
storedUser, err := client.User.Get(ctx, sourceUser.ID)
require.NoError(t, err)
require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email)
require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("source-user@example.com").
SetUsername("source-user").
SetPasswordHash("old-hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password")
require.ErrorIs(t, err, service.ErrEmailReserved)
require.Nil(t, updatedUser)
}
type emailBindSettingRepoStub struct {
values map[string]string
}
func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected Set call")
}
func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if v, ok := s.values[key]; ok {
out[key] = v
}
}
return out, nil
}
func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *emailBindSettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
type emailBindCacheStub struct {
data *service.VerificationCodeData
err error
}
func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
}
return s.data, nil
}
func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
return nil, nil
}
func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
return false
}
func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}

View File

@@ -449,6 +449,16 @@ export async function sendVerifyCode(
return data return data
} }
export async function sendPendingOAuthVerifyCode(
request: SendVerifyCodeRequest
): Promise<SendVerifyCodeResponse> {
const { data } = await apiClient.post<SendVerifyCodeResponse>(
'/auth/oauth/pending/send-verify-code',
request
)
return data
}
/** /**
* Validate promo code response * Validate promo code response
*/ */
@@ -638,6 +648,7 @@ export const authAPI = {
clearAuthToken, clearAuthToken,
getPublicSettings, getPublicSettings,
sendVerifyCode, sendVerifyCode,
sendPendingOAuthVerifyCode,
validatePromoCode, validatePromoCode,
validateInvitationCode, validateInvitationCode,
forgotPassword, forgotPassword,

View File

@@ -89,6 +89,19 @@ export async function toggleNotifyEmail(email: string, disabled: boolean): Promi
return data return data
} }
export async function sendEmailBindingCode(email: string): Promise<void> {
await apiClient.post('/user/account-bindings/email/send-code', { email })
}
export async function bindEmailIdentity(payload: {
email: string
verify_code: string
password: string
}): Promise<User> {
const { data } = await apiClient.post<User>('/user/account-bindings/email', payload)
return data
}
export type BindableOAuthProvider = Exclude<UserAuthProvider, 'email'> export type BindableOAuthProvider = Exclude<UserAuthProvider, 'email'>
interface BuildOAuthBindingStartURLOptions { interface BuildOAuthBindingStartURLOptions {
@@ -158,6 +171,8 @@ export const userAPI = {
verifyNotifyEmail, verifyNotifyEmail,
removeNotifyEmail, removeNotifyEmail,
toggleNotifyEmail, toggleNotifyEmail,
sendEmailBindingCode,
bindEmailIdentity,
buildOAuthBindingStartURL, buildOAuthBindingStartURL,
startOAuthBinding startOAuthBinding
} }

View File

@@ -58,11 +58,20 @@
<p v-else class="text-xs text-gray-500 dark:text-dark-400"> <p v-else class="text-xs text-gray-500 dark:text-dark-400">
{{ t('auth.verificationCodeHint') }} {{ t('auth.verificationCodeHint') }}
</p> </p>
<input
v-if="invitationCodeEnabled"
v-model="invitationCode"
:data-testid="`${testIdPrefix}-create-account-invitation-code`"
type="text"
class="input w-full"
:placeholder="t('auth.invitationCodePlaceholder')"
:disabled="isSubmitting"
/>
<button <button
:data-testid="`${testIdPrefix}-create-account-submit`" :data-testid="`${testIdPrefix}-create-account-submit`"
type="button" type="button"
class="btn btn-primary w-full" class="btn btn-primary w-full"
:disabled="isSubmitting || !email.trim() || password.length < 6" :disabled="isSubmitting || !email.trim() || password.length < 6 || (invitationCodeEnabled && !invitationCode.trim())"
@click="handleSubmit" @click="handleSubmit"
> >
{{ isSubmitting ? t('common.processing') : 'Create account' }} {{ isSubmitting ? t('common.processing') : 'Create account' }}
@@ -92,12 +101,13 @@
import { onMounted, onUnmounted, ref, watch } from 'vue' import { onMounted, onUnmounted, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import TurnstileWidget from '@/components/TurnstileWidget.vue' import TurnstileWidget from '@/components/TurnstileWidget.vue'
import { getPublicSettings, sendVerifyCode } from '@/api/auth' import { getPublicSettings, sendPendingOAuthVerifyCode } from '@/api/auth'
export type PendingOAuthCreateAccountPayload = { export type PendingOAuthCreateAccountPayload = {
email: string email: string
password: string password: string
verifyCode: string verifyCode: string
invitationCode?: string
} }
const props = defineProps<{ const props = defineProps<{
@@ -117,10 +127,12 @@ const { t } = useI18n()
const email = ref('') const email = ref('')
const password = ref('') const password = ref('')
const verifyCode = ref('') const verifyCode = ref('')
const invitationCode = ref('')
const isSendingCode = ref(false) const isSendingCode = ref(false)
const sendCodeError = ref('') const sendCodeError = ref('')
const sendCodeSuccess = ref(false) const sendCodeSuccess = ref(false)
const countdown = ref(0) const countdown = ref(0)
const invitationCodeEnabled = ref(false)
const turnstileEnabled = ref(false) const turnstileEnabled = ref(false)
const turnstileSiteKey = ref('') const turnstileSiteKey = ref('')
const turnstileToken = ref('') const turnstileToken = ref('')
@@ -203,7 +215,7 @@ async function handleSendCode() {
sendCodeSuccess.value = false sendCodeSuccess.value = false
try { try {
const response = await sendVerifyCode({ const response = await sendPendingOAuthVerifyCode({
email: trimmedEmail, email: trimmedEmail,
turnstile_token: turnstileEnabled.value ? turnstileToken.value : undefined turnstile_token: turnstileEnabled.value ? turnstileToken.value : undefined
}) })
@@ -228,7 +240,8 @@ function handleSubmit() {
emit('submit', { emit('submit', {
email: trimmedEmail, email: trimmedEmail,
password: password.value, password: password.value,
verifyCode: verifyCode.value.trim() verifyCode: verifyCode.value.trim(),
invitationCode: invitationCode.value.trim() || undefined
}) })
} }
@@ -239,9 +252,11 @@ function emitSwitchToBind() {
onMounted(async () => { onMounted(async () => {
try { try {
const settings = await getPublicSettings() const settings = await getPublicSettings()
invitationCodeEnabled.value = settings.invitation_code_enabled === true
turnstileEnabled.value = settings.turnstile_enabled === true turnstileEnabled.value = settings.turnstile_enabled === true
turnstileSiteKey.value = settings.turnstile_site_key || '' turnstileSiteKey.value = settings.turnstile_site_key || ''
} catch { } catch {
invitationCodeEnabled.value = false
turnstileEnabled.value = false turnstileEnabled.value = false
turnstileSiteKey.value = '' turnstileSiteKey.value = ''
} }

View File

@@ -4,6 +4,7 @@ import { flushPromises, mount } from '@vue/test-utils'
import PendingOAuthCreateAccountForm from '../PendingOAuthCreateAccountForm.vue' import PendingOAuthCreateAccountForm from '../PendingOAuthCreateAccountForm.vue'
const sendVerifyCode = vi.fn() const sendVerifyCode = vi.fn()
const sendPendingOAuthVerifyCode = vi.fn()
const getPublicSettings = vi.fn() const getPublicSettings = vi.fn()
vi.mock('vue-i18n', async () => { vi.mock('vue-i18n', async () => {
@@ -21,6 +22,7 @@ vi.mock('@/api/auth', async () => {
return { return {
...actual, ...actual,
sendVerifyCode: (...args: any[]) => sendVerifyCode(...args), sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args) getPublicSettings: (...args: any[]) => getPublicSettings(...args)
} }
}) })
@@ -28,6 +30,7 @@ vi.mock('@/api/auth', async () => {
describe('PendingOAuthCreateAccountForm', () => { describe('PendingOAuthCreateAccountForm', () => {
beforeEach(() => { beforeEach(() => {
sendVerifyCode.mockReset() sendVerifyCode.mockReset()
sendPendingOAuthVerifyCode.mockReset()
getPublicSettings.mockReset() getPublicSettings.mockReset()
getPublicSettings.mockResolvedValue({ getPublicSettings.mockResolvedValue({
turnstile_enabled: false, turnstile_enabled: false,
@@ -61,8 +64,42 @@ describe('PendingOAuthCreateAccountForm', () => {
]) ])
}) })
it('shows and emits invitation code when invitation-only signup is enabled', async () => {
getPublicSettings.mockResolvedValue({
invitation_code_enabled: true,
turnstile_enabled: false,
turnstile_site_key: ''
})
const wrapper = mount(PendingOAuthCreateAccountForm, {
props: {
providerName: 'LinuxDo',
testIdPrefix: 'linuxdo',
initialEmail: 'prefill@example.com',
isSubmitting: false
}
})
await flushPromises()
await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('form').trigger('submit.prevent')
expect(wrapper.emitted('submit')).toEqual([
[
{
email: 'prefill@example.com',
password: 'secret-123',
verifyCode: '246810',
invitationCode: 'INVITE123'
}
]
])
})
it('sends a verify code for the trimmed email value', async () => { it('sends a verify code for the trimmed email value', async () => {
sendVerifyCode.mockResolvedValue({ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent', message: 'sent',
countdown: 60 countdown: 60
}) })
@@ -80,7 +117,7 @@ describe('PendingOAuthCreateAccountForm', () => {
await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click') await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
await flushPromises() await flushPromises()
expect(sendVerifyCode).toHaveBeenCalledWith({ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'user@example.com' email: 'user@example.com'
}) })
}) })
@@ -90,7 +127,7 @@ describe('PendingOAuthCreateAccountForm', () => {
turnstile_enabled: true, turnstile_enabled: true,
turnstile_site_key: 'site-key' turnstile_site_key: 'site-key'
}) })
sendVerifyCode.mockResolvedValue({ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent', message: 'sent',
countdown: 60 countdown: 60
}) })
@@ -120,7 +157,7 @@ describe('PendingOAuthCreateAccountForm', () => {
await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click') await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
await flushPromises() await flushPromises()
expect(sendVerifyCode).toHaveBeenCalledWith({ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'user@example.com', email: 'user@example.com',
turnstile_token: 'turnstile-token' turnstile_token: 'turnstile-token'
}) })

View File

@@ -13,15 +13,14 @@
<div <div
v-for="item in providerItems" v-for="item in providerItems"
:key="item.provider" :key="item.provider"
class="flex items-center justify-between gap-3 rounded-xl bg-white/80 px-3 py-2.5 dark:bg-dark-800/70" class="rounded-xl bg-white/80 px-3 py-3 dark:bg-dark-800/70"
> >
<div class="min-w-0"> <div class="flex flex-col gap-3 sm:flex-row sm:items-start sm:justify-between sm:gap-4">
<div class="min-w-0 flex-1">
<div class="flex items-center gap-2">
<div class="text-sm font-medium text-gray-900 dark:text-white"> <div class="text-sm font-medium text-gray-900 dark:text-white">
{{ item.label }} {{ item.label }}
</div> </div>
</div>
<div class="flex shrink-0 items-center gap-2">
<span <span
:data-testid="`profile-binding-${item.provider}-status`" :data-testid="`profile-binding-${item.provider}-status`"
:class="['badge', item.bound ? 'badge-success' : 'badge-gray']" :class="['badge', item.bound ? 'badge-success' : 'badge-gray']"
@@ -32,7 +31,68 @@
: t('profile.authBindings.status.notBound') : t('profile.authBindings.status.notBound')
}} }}
</span> </span>
</div>
<div
v-if="item.provider === 'email' && !item.bound"
class="mt-3 grid gap-2 sm:grid-cols-[minmax(0,1.4fr)_auto]"
>
<input
v-model.trim="emailBindingForm.email"
data-testid="profile-binding-email-input"
type="email"
class="input"
:placeholder="t('profile.authBindings.emailPlaceholder')"
:disabled="isSendingEmailCode || isBindingEmail"
/>
<button
data-testid="profile-binding-email-send-code"
type="button"
class="btn btn-secondary btn-sm"
:disabled="isSendingEmailCode || isBindingEmail"
@click="sendEmailCode"
>
{{
isSendingEmailCode
? t('common.loading')
: t('profile.authBindings.sendCodeAction')
}}
</button>
<input
v-model.trim="emailBindingForm.verifyCode"
data-testid="profile-binding-email-code-input"
type="text"
inputmode="numeric"
maxlength="6"
class="input"
:placeholder="t('profile.authBindings.codePlaceholder')"
:disabled="isBindingEmail"
/>
<input
v-model="emailBindingForm.password"
data-testid="profile-binding-email-password-input"
type="password"
class="input"
:placeholder="t('profile.authBindings.passwordPlaceholder')"
:disabled="isBindingEmail"
/>
<button
data-testid="profile-binding-email-submit"
type="button"
class="btn btn-primary btn-sm sm:col-span-2"
:disabled="isBindingEmail"
@click="bindEmail"
>
{{
isBindingEmail
? t('common.loading')
: t('profile.authBindings.confirmEmailBindAction')
}}
</button>
</div>
</div>
<div class="flex shrink-0 items-center gap-2">
<button <button
v-if="item.canBind" v-if="item.canBind"
:data-testid="`profile-binding-${item.provider}-action`" :data-testid="`profile-binding-${item.provider}-action`"
@@ -46,10 +106,11 @@
</div> </div>
</div> </div>
</div> </div>
</div>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { computed } from 'vue' import { computed, reactive, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useRoute } from 'vue-router' import { useRoute } from 'vue-router'
import { import {
@@ -57,8 +118,8 @@ import {
resolveWeChatOAuthStartStrict, resolveWeChatOAuthStartStrict,
type WeChatOAuthPublicSettings, type WeChatOAuthPublicSettings,
} from '@/api/auth' } from '@/api/auth'
import { startOAuthBinding } from '@/api/user' import { bindEmailIdentity, sendEmailBindingCode, startOAuthBinding } from '@/api/user'
import { useAppStore } from '@/stores' import { useAppStore, useAuthStore } from '@/stores'
import type { User, UserAuthBindingStatus, UserAuthProvider } from '@/types' import type { User, UserAuthBindingStatus, UserAuthProvider } from '@/types'
const props = withDefaults( const props = withDefaults(
@@ -84,6 +145,32 @@ const props = withDefaults(
const { t } = useI18n() const { t } = useI18n()
const route = useRoute() const route = useRoute()
const appStore = useAppStore() const appStore = useAppStore()
const authStore = useAuthStore()
const localUser = ref<User | null>(null)
const isSendingEmailCode = ref(false)
const isBindingEmail = ref(false)
const emailBindingForm = reactive({
email: '',
verifyCode: '',
password: '',
})
watch(
() => props.user,
(user) => {
localUser.value = null
if (!user || getBindingStatusForUser(user, 'email')) {
return
}
if (typeof user.email === 'string' && !user.email.endsWith('.invalid')) {
emailBindingForm.email = user.email
}
},
{ immediate: true }
)
const currentUser = computed(() => localUser.value ?? props.user)
const wechatOAuthSettings = computed<WeChatOAuthPublicSettings | null>(() => { const wechatOAuthSettings = computed<WeChatOAuthPublicSettings | null>(() => {
if (hasExplicitWeChatOAuthCapabilities(appStore.cachedPublicSettings)) { if (hasExplicitWeChatOAuthCapabilities(appStore.cachedPublicSettings)) {
@@ -117,20 +204,20 @@ function normalizeBindingStatus(binding: boolean | UserAuthBindingStatus | undef
} }
function getBindingStatus(provider: UserAuthProvider): boolean { function getBindingStatus(provider: UserAuthProvider): boolean {
const currentUser = props.user return getBindingStatusForUser(currentUser.value, provider)
}
function getBindingStatusForUser(user: User | null | undefined, provider: UserAuthProvider): boolean {
if (provider === 'email') { if (provider === 'email') {
return typeof currentUser?.email_bound === 'boolean' return typeof user?.email_bound === 'boolean' ? user.email_bound : Boolean(user?.email)
? currentUser.email_bound
: Boolean(currentUser?.email)
} }
const directFlag = currentUser?.[`${provider}_bound` as keyof User] const directFlag = user?.[`${provider}_bound` as keyof User]
if (typeof directFlag === 'boolean') { if (typeof directFlag === 'boolean') {
return directFlag return directFlag
} }
const nested = currentUser?.auth_bindings?.[provider] ?? currentUser?.identity_bindings?.[provider] const nested = user?.auth_bindings?.[provider] ?? user?.identity_bindings?.[provider]
const normalized = normalizeBindingStatus(nested) const normalized = normalizeBindingStatus(nested)
return normalized ?? false return normalized ?? false
} }
@@ -171,4 +258,72 @@ function startBinding(provider: UserAuthProvider): void {
wechatOAuthSettings: provider === 'wechat' ? wechatOAuthSettings.value : null, wechatOAuthSettings: provider === 'wechat' ? wechatOAuthSettings.value : null,
}) })
} }
function applyUpdatedUser(user: User): void {
localUser.value = user
authStore.user = user
}
function validateEmailBindingForm(requireCode: boolean): boolean {
if (!emailBindingForm.email) {
appStore.showError(t('auth.emailRequired'))
return false
}
if (!/^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(emailBindingForm.email)) {
appStore.showError(t('auth.invalidEmail'))
return false
}
if (requireCode && !emailBindingForm.verifyCode) {
appStore.showError(t('auth.codeRequired'))
return false
}
if (requireCode && !emailBindingForm.password) {
appStore.showError(t('auth.passwordRequired'))
return false
}
if (requireCode && emailBindingForm.password.length < 6) {
appStore.showError(t('auth.passwordMinLength'))
return false
}
return true
}
async function sendEmailCode(): Promise<void> {
if (!validateEmailBindingForm(false)) {
return
}
isSendingEmailCode.value = true
try {
await sendEmailBindingCode(emailBindingForm.email)
appStore.showSuccess(t('profile.authBindings.codeSentTo', { email: emailBindingForm.email }))
} catch (error) {
appStore.showError((error as { message?: string }).message || t('auth.sendCodeFailed'))
} finally {
isSendingEmailCode.value = false
}
}
async function bindEmail(): Promise<void> {
if (!validateEmailBindingForm(true)) {
return
}
isBindingEmail.value = true
try {
const user = await bindEmailIdentity({
email: emailBindingForm.email,
verify_code: emailBindingForm.verifyCode,
password: emailBindingForm.password,
})
applyUpdatedUser(user)
emailBindingForm.verifyCode = ''
emailBindingForm.password = ''
appStore.showSuccess(t('profile.authBindings.bindSuccess'))
} catch (error) {
appStore.showError((error as { message?: string }).message || t('common.tryAgain'))
} finally {
isBindingEmail.value = false
}
}
</script> </script>

View File

@@ -2,7 +2,7 @@ import { mount } from '@vue/test-utils'
import { createPinia, setActivePinia } from 'pinia' import { createPinia, setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue' import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue'
import { useAppStore } from '@/stores' import { useAppStore, useAuthStore } from '@/stores'
import type { User } from '@/types' import type { User } from '@/types'
const routeState = vi.hoisted(() => ({ const routeState = vi.hoisted(() => ({
@@ -15,10 +15,24 @@ const locationState = vi.hoisted(() => ({
let pinia: ReturnType<typeof createPinia> let pinia: ReturnType<typeof createPinia>
const userApiMocks = vi.hoisted(() => ({
sendEmailBindingCode: vi.fn(),
bindEmailIdentity: vi.fn(),
}))
vi.mock('vue-router', () => ({ vi.mock('vue-router', () => ({
useRoute: () => routeState, useRoute: () => routeState,
})) }))
vi.mock('@/api/user', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/api/user')>()
return {
...actual,
sendEmailBindingCode: (...args: any[]) => userApiMocks.sendEmailBindingCode(...args),
bindEmailIdentity: (...args: any[]) => userApiMocks.bindEmailIdentity(...args),
}
})
vi.mock('vue-i18n', async (importOriginal) => { vi.mock('vue-i18n', async (importOriginal) => {
const actual = await importOriginal<typeof import('vue-i18n')>() const actual = await importOriginal<typeof import('vue-i18n')>()
return { return {
@@ -34,6 +48,13 @@ vi.mock('vue-i18n', async (importOriginal) => {
if (key === 'profile.authBindings.providers.wechat') return 'WeChat' if (key === 'profile.authBindings.providers.wechat') return 'WeChat'
if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC' if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC'
if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim() if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim()
if (key === 'profile.authBindings.emailPlaceholder') return 'Email address'
if (key === 'profile.authBindings.codePlaceholder') return 'Verification code'
if (key === 'profile.authBindings.passwordPlaceholder') return 'Set password'
if (key === 'profile.authBindings.sendCodeAction') return 'Send code'
if (key === 'profile.authBindings.confirmEmailBindAction') return 'Bind email'
if (key === 'profile.authBindings.codeSentTo') return `Code sent to ${params?.email || ''}`.trim()
if (key === 'profile.authBindings.bindSuccess') return 'Bind success'
return key return key
}, },
}), }),
@@ -76,6 +97,8 @@ describe('ProfileIdentityBindingsSection', () => {
const appStore = useAppStore() const appStore = useAppStore()
appStore.cachedPublicSettings = null appStore.cachedPublicSettings = null
appStore.publicSettingsLoaded = false appStore.publicSettingsLoaded = false
userApiMocks.sendEmailBindingCode.mockReset()
userApiMocks.bindEmailIdentity.mockReset()
}) })
afterEach(() => { afterEach(() => {
@@ -224,4 +247,58 @@ describe('ProfileIdentityBindingsSection', () => {
expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(true) expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(true)
}) })
it('sends email verification code and binds email from the profile card', async () => {
userApiMocks.sendEmailBindingCode.mockResolvedValue(undefined)
userApiMocks.bindEmailIdentity.mockResolvedValue(
createUser({
email: 'bound@example.com',
email_bound: true,
auth_bindings: {
email: { bound: true },
},
})
)
const appStore = useAppStore()
const authStore = useAuthStore()
authStore.user = createUser({
email: 'legacy-user@linuxdo-connect.invalid',
email_bound: false,
auth_bindings: {
email: { bound: false },
},
})
const showSuccessSpy = vi.spyOn(appStore, 'showSuccess')
const wrapper = mount(ProfileIdentityBindingsSection, {
global: {
plugins: [pinia],
},
props: {
user: authStore.user,
linuxdoEnabled: false,
oidcEnabled: false,
wechatEnabled: false,
},
})
await wrapper.get('[data-testid="profile-binding-email-input"]').setValue('bound@example.com')
await wrapper.get('[data-testid="profile-binding-email-send-code"]').trigger('click')
expect(userApiMocks.sendEmailBindingCode).toHaveBeenCalledWith('bound@example.com')
expect(showSuccessSpy).toHaveBeenCalledWith('Code sent to bound@example.com')
await wrapper.get('[data-testid="profile-binding-email-code-input"]').setValue('123456')
await wrapper.get('[data-testid="profile-binding-email-password-input"]').setValue('new-password')
await wrapper.get('[data-testid="profile-binding-email-submit"]').trigger('click')
expect(userApiMocks.bindEmailIdentity).toHaveBeenCalledWith({
email: 'bound@example.com',
verify_code: '123456',
password: 'new-password',
})
expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
expect(authStore.user?.email).toBe('bound@example.com')
})
}) })

View File

@@ -964,6 +964,12 @@ export default {
description: 'View current bindings and connect another provider to this account.', description: 'View current bindings and connect another provider to this account.',
bindAction: 'Bind {providerName}', bindAction: 'Bind {providerName}',
bindSuccess: 'Account linked successfully', bindSuccess: 'Account linked successfully',
emailPlaceholder: 'Enter email address',
codePlaceholder: 'Enter verification code',
passwordPlaceholder: 'Set a login password',
sendCodeAction: 'Send code',
confirmEmailBindAction: 'Bind email',
codeSentTo: 'Code sent to {email}',
status: { status: {
bound: 'Bound', bound: 'Bound',
notBound: 'Not bound', notBound: 'Not bound',

View File

@@ -968,6 +968,12 @@ export default {
description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。', description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。',
bindAction: '绑定 {providerName}', bindAction: '绑定 {providerName}',
bindSuccess: '账号绑定成功', bindSuccess: '账号绑定成功',
emailPlaceholder: '输入邮箱地址',
codePlaceholder: '输入验证码',
passwordPlaceholder: '设置登录密码',
sendCodeAction: '发送验证码',
confirmEmailBindAction: '绑定邮箱',
codeSentTo: '验证码已发送到 {email}',
status: { status: {
bound: '已绑定', bound: '已绑定',
notBound: '未绑定', notBound: '未绑定',

View File

@@ -118,6 +118,8 @@ export interface RegisterRequest {
export interface SendVerifyCodeRequest { export interface SendVerifyCodeRequest {
email: string email: string
turnstile_token?: string turnstile_token?: string
pending_auth_token?: string
pending_oauth_token?: string
} }
export interface SendVerifyCodeResponse { export interface SendVerifyCodeResponse {

View File

@@ -176,7 +176,12 @@ import { AuthLayout } from '@/components/layout'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import TurnstileWidget from '@/components/TurnstileWidget.vue' import TurnstileWidget from '@/components/TurnstileWidget.vue'
import { useAuthStore, useAppStore } from '@/stores' import { useAuthStore, useAppStore } from '@/stores'
import { persistOAuthTokenContext, getPublicSettings, sendVerifyCode } from '@/api/auth' import {
persistOAuthTokenContext,
getPublicSettings,
sendPendingOAuthVerifyCode,
sendVerifyCode,
} from '@/api/auth'
import { apiClient } from '@/api/client' import { apiClient } from '@/api/client'
import { buildAuthErrorMessage } from '@/utils/authError' import { buildAuthErrorMessage } from '@/utils/authError'
import { import {
@@ -355,18 +360,21 @@ async function sendCode(): Promise<void> {
errorMessage.value = '' errorMessage.value = ''
try { try {
if (!isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) { if (!pendingAuthToken.value && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
errorMessage.value = buildEmailSuffixNotAllowedMessage() errorMessage.value = buildEmailSuffixNotAllowedMessage()
appStore.showError(errorMessage.value) appStore.showError(errorMessage.value)
return return
} }
const response = await sendVerifyCode({ const requestPayload = {
email: email.value, email: email.value,
[pendingAuthTokenField.value]: pendingAuthToken.value || undefined, [pendingAuthTokenField.value]: pendingAuthToken.value || undefined,
// 优先使用重发时新获取的 token因为初始 token 可能已被使用) // 优先使用重发时新获取的 token因为初始 token 可能已被使用)
turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined
} as Parameters<typeof sendVerifyCode>[0]) } as Parameters<typeof sendVerifyCode>[0]
const response = pendingAuthToken.value
? await sendPendingOAuthVerifyCode(requestPayload)
: await sendVerifyCode(requestPayload)
codeSent.value = true codeSent.value = true
startCountdown(response.countdown) startCountdown(response.countdown)

View File

@@ -444,6 +444,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
} }
function isCreateAccountRecoveryError(error: unknown): boolean {
const data = (error as {
response?: {
data?: {
reason?: string
error?: string
code?: string
step?: string
intent?: string
}
}
}).response?.data
const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
.map(value => value?.trim().toLowerCase())
.filter((value): value is string => Boolean(value))
return states.includes('email_exists') ||
states.includes('bind_login_required') ||
states.includes('bind_login') ||
states.includes('adopt_existing_user_by_email')
}
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') { if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -540,10 +562,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email, email: payload.email,
password: payload.password, password: payload.password,
verify_code: payload.verifyCode || undefined, verify_code: payload.verifyCode || undefined,
invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision()) ...serializeAdoptionDecision(currentAdoptionDecision())
}) })
await finalizePendingAccountResponse(data) await finalizePendingAccountResponse(data)
} catch (e: unknown) { } catch (e: unknown) {
if (isCreateAccountRecoveryError(e)) {
switchToBindLoginMode(payload.email)
return
}
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed')) accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally { } finally {
isSubmitting.value = false isSubmitting.value = false

View File

@@ -488,6 +488,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
} }
function isCreateAccountRecoveryError(error: unknown): boolean {
const data = (error as {
response?: {
data?: {
reason?: string
error?: string
code?: string
step?: string
intent?: string
}
}
}).response?.data
const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
.map(value => value?.trim().toLowerCase())
.filter((value): value is string => Boolean(value))
return states.includes('email_exists') ||
states.includes('bind_login_required') ||
states.includes('bind_login') ||
states.includes('adopt_existing_user_by_email')
}
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') { if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -584,10 +606,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email, email: payload.email,
password: payload.password, password: payload.password,
verify_code: payload.verifyCode || undefined, verify_code: payload.verifyCode || undefined,
invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision()) ...serializeAdoptionDecision(currentAdoptionDecision())
}) })
await finalizePendingAccountResponse(data) await finalizePendingAccountResponse(data)
} catch (e: unknown) { } catch (e: unknown) {
if (isCreateAccountRecoveryError(e)) {
switchToBindLoginMode(payload.email)
return
}
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed')) accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally { } finally {
isSubmitting.value = false isSubmitting.value = false

View File

@@ -647,6 +647,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
} }
function isCreateAccountRecoveryError(error: unknown): boolean {
const data = (error as {
response?: {
data?: {
reason?: string
error?: string
code?: string
step?: string
intent?: string
}
}
}).response?.data
const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
.map(value => value?.trim().toLowerCase())
.filter((value): value is string => Boolean(value))
return states.includes('email_exists') ||
states.includes('bind_login_required') ||
states.includes('bind_login') ||
states.includes('adopt_existing_user_by_email')
}
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') { if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -739,10 +761,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email, email: payload.email,
password: payload.password, password: payload.password,
verify_code: payload.verifyCode || undefined, verify_code: payload.verifyCode || undefined,
invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision()) ...serializeAdoptionDecision(currentAdoptionDecision())
}) })
await finalizePendingAccountResponse(data) await finalizePendingAccountResponse(data)
} catch (e: unknown) { } catch (e: unknown) {
if (isCreateAccountRecoveryError(e)) {
switchToBindLoginMode(payload.email)
return
}
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed')) accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally { } finally {
isSubmitting.value = false isSubmitting.value = false

View File

@@ -11,6 +11,7 @@ const {
clearPendingAuthSessionMock, clearPendingAuthSessionMock,
getPublicSettingsMock, getPublicSettingsMock,
sendVerifyCodeMock, sendVerifyCodeMock,
sendPendingOAuthVerifyCodeMock,
persistOAuthTokenContextMock, persistOAuthTokenContextMock,
apiClientPostMock, apiClientPostMock,
authStoreState, authStoreState,
@@ -23,6 +24,7 @@ const {
clearPendingAuthSessionMock: vi.fn(), clearPendingAuthSessionMock: vi.fn(),
getPublicSettingsMock: vi.fn(), getPublicSettingsMock: vi.fn(),
sendVerifyCodeMock: vi.fn(), sendVerifyCodeMock: vi.fn(),
sendPendingOAuthVerifyCodeMock: vi.fn(),
persistOAuthTokenContextMock: vi.fn(), persistOAuthTokenContextMock: vi.fn(),
apiClientPostMock: vi.fn(), apiClientPostMock: vi.fn(),
authStoreState: { authStoreState: {
@@ -80,6 +82,7 @@ vi.mock('@/api/auth', async () => {
...actual, ...actual,
getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args), getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args),
sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args), sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args),
sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args),
persistOAuthTokenContext: (...args: any[]) => persistOAuthTokenContextMock(...args), persistOAuthTokenContext: (...args: any[]) => persistOAuthTokenContextMock(...args),
} }
}) })
@@ -100,6 +103,7 @@ describe('EmailVerifyView', () => {
clearPendingAuthSessionMock.mockReset() clearPendingAuthSessionMock.mockReset()
getPublicSettingsMock.mockReset() getPublicSettingsMock.mockReset()
sendVerifyCodeMock.mockReset() sendVerifyCodeMock.mockReset()
sendPendingOAuthVerifyCodeMock.mockReset()
persistOAuthTokenContextMock.mockReset() persistOAuthTokenContextMock.mockReset()
apiClientPostMock.mockReset() apiClientPostMock.mockReset()
authStoreState.pendingAuthSession = null authStoreState.pendingAuthSession = null
@@ -112,9 +116,86 @@ describe('EmailVerifyView', () => {
registration_email_suffix_whitelist: [], registration_email_suffix_whitelist: [],
}) })
sendVerifyCodeMock.mockResolvedValue({ countdown: 60 }) sendVerifyCodeMock.mockResolvedValue({ countdown: 60 })
sendPendingOAuthVerifyCodeMock.mockResolvedValue({ countdown: 60 })
setTokenMock.mockResolvedValue({}) setTokenMock.mockResolvedValue({})
}) })
it('uses the pending oauth verify-code endpoint when register data carries a pending auth session', async () => {
authStoreState.pendingAuthSession = {
token: 'pending-token-1',
token_field: 'pending_auth_token',
provider: 'wechat',
redirect: '/profile',
}
sessionStorage.setItem(
'register_data',
JSON.stringify({
email: 'fresh@example.com',
password: 'secret-123',
})
)
mount(EmailVerifyView, {
global: {
stubs: {
AuthLayout: { template: '<div><slot /><slot name="footer" /></div>' },
Icon: true,
TurnstileWidget: true,
transition: false,
},
},
})
await flushPromises()
expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
email: 'fresh@example.com',
pending_auth_token: 'pending-token-1',
})
expect(sendVerifyCodeMock).not.toHaveBeenCalled()
})
it('skips the registration email suffix whitelist for pending oauth verification', async () => {
authStoreState.pendingAuthSession = {
token: 'pending-token-2',
token_field: 'pending_auth_token',
provider: 'oidc',
redirect: '/profile',
}
getPublicSettingsMock.mockResolvedValue({
turnstile_enabled: false,
turnstile_site_key: '',
site_name: 'Sub2API',
registration_email_suffix_whitelist: ['allowed.com'],
})
sessionStorage.setItem(
'register_data',
JSON.stringify({
email: 'fresh@example.com',
password: 'secret-123',
})
)
mount(EmailVerifyView, {
global: {
stubs: {
AuthLayout: { template: '<div><slot /><slot name="footer" /></div>' },
Icon: true,
TurnstileWidget: true,
transition: false,
},
},
})
await flushPromises()
expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
email: 'fresh@example.com',
pending_auth_token: 'pending-token-2',
})
expect(showErrorMock).not.toHaveBeenCalled()
})
it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => { it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => {
authStoreState.pendingAuthSession = { authStoreState.pendingAuthSession = {
token: 'pending-token-1', token: 'pending-token-1',

View File

@@ -15,6 +15,7 @@ const getPublicSettings = vi.fn()
const login2FA = vi.fn() const login2FA = vi.fn()
const apiClientPost = vi.fn() const apiClientPost = vi.fn()
const sendVerifyCode = vi.fn() const sendVerifyCode = vi.fn()
const sendPendingOAuthVerifyCode = vi.fn()
vi.mock('vue-router', () => ({ vi.mock('vue-router', () => ({
useRoute: () => ({ useRoute: () => ({
@@ -61,7 +62,8 @@ vi.mock('@/api/auth', async () => {
completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args), completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args), getPublicSettings: (...args: any[]) => getPublicSettings(...args),
login2FA: (...args: any[]) => login2FA(...args), login2FA: (...args: any[]) => login2FA(...args),
sendVerifyCode: (...args: any[]) => sendVerifyCode(...args) sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args)
} }
}) })
@@ -79,6 +81,7 @@ describe('LinuxDoCallbackView', () => {
login2FA.mockReset() login2FA.mockReset()
apiClientPost.mockReset() apiClientPost.mockReset()
sendVerifyCode.mockReset() sendVerifyCode.mockReset()
sendPendingOAuthVerifyCode.mockReset()
getPublicSettings.mockResolvedValue({ getPublicSettings.mockResolvedValue({
turnstile_enabled: false, turnstile_enabled: false,
turnstile_site_key: '' turnstile_site_key: ''
@@ -334,6 +337,11 @@ describe('LinuxDoCallbackView', () => {
}) })
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
getPublicSettings.mockResolvedValue({
invitation_code_enabled: true,
turnstile_enabled: false,
turnstile_site_key: ''
})
exchangePendingOAuthCompletion.mockResolvedValue({ exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required', error: 'email_required',
redirect: '/welcome', redirect: '/welcome',
@@ -370,6 +378,7 @@ describe('LinuxDoCallbackView', () => {
await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ') await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810') await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click')
await flushPromises() await flushPromises()
@@ -377,6 +386,7 @@ describe('LinuxDoCallbackView', () => {
email: 'new@example.com', email: 'new@example.com',
password: 'secret-123', password: 'secret-123',
verify_code: '246810', verify_code: '246810',
invitation_code: 'INVITE123',
adopt_display_name: true, adopt_display_name: true,
adopt_avatar: false adopt_avatar: false
}) })
@@ -384,12 +394,48 @@ describe('LinuxDoCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome') expect(replace).toHaveBeenCalledWith('/welcome')
}) })
it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome'
})
apiClientPost.mockRejectedValue({
response: {
data: {
reason: 'EMAIL_EXISTS',
message: 'email already exists'
}
}
})
const wrapper = mount(LinuxDoCallbackView, {
global: {
stubs: {
AuthLayout: { template: '<div><slot /></div>' },
Icon: true,
RouterLink: { template: '<a><slot /></a>' },
transition: false
}
}
})
await flushPromises()
await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('existing@example.com')
await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click')
await flushPromises()
expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe(
'existing@example.com'
)
})
it('sends a verify code for pending oauth account creation', async () => { it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({ exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required', error: 'email_required',
redirect: '/welcome' redirect: '/welcome'
}) })
sendVerifyCode.mockResolvedValue({ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent', message: 'sent',
countdown: 60 countdown: 60
}) })
@@ -411,7 +457,7 @@ describe('LinuxDoCallbackView', () => {
await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click') await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
await flushPromises() await flushPromises()
expect(sendVerifyCode).toHaveBeenCalledWith({ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'new@example.com' email: 'new@example.com'
}) })
}) })

View File

@@ -15,6 +15,7 @@ const getPublicSettings = vi.fn()
const login2FA = vi.fn() const login2FA = vi.fn()
const apiClientPost = vi.fn() const apiClientPost = vi.fn()
const sendVerifyCode = vi.fn() const sendVerifyCode = vi.fn()
const sendPendingOAuthVerifyCode = vi.fn()
vi.mock('vue-router', () => ({ vi.mock('vue-router', () => ({
useRoute: () => ({ useRoute: () => ({
@@ -66,7 +67,8 @@ vi.mock('@/api/auth', async () => {
completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args), getPublicSettings: (...args: any[]) => getPublicSettings(...args),
login2FA: (...args: any[]) => login2FA(...args), login2FA: (...args: any[]) => login2FA(...args),
sendVerifyCode: (...args: any[]) => sendVerifyCode(...args) sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args)
} }
}) })
@@ -84,6 +86,7 @@ describe('OidcCallbackView', () => {
login2FA.mockReset() login2FA.mockReset()
apiClientPost.mockReset() apiClientPost.mockReset()
sendVerifyCode.mockReset() sendVerifyCode.mockReset()
sendPendingOAuthVerifyCode.mockReset()
getPublicSettings.mockResolvedValue({ getPublicSettings.mockResolvedValue({
oidc_oauth_provider_name: 'ExampleID', oidc_oauth_provider_name: 'ExampleID',
turnstile_enabled: false, turnstile_enabled: false,
@@ -312,6 +315,12 @@ describe('OidcCallbackView', () => {
}) })
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
getPublicSettings.mockResolvedValue({
oidc_oauth_provider_name: 'ExampleID',
invitation_code_enabled: true,
turnstile_enabled: false,
turnstile_site_key: ''
})
exchangePendingOAuthCompletion.mockResolvedValue({ exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required', error: 'email_required',
redirect: '/welcome', redirect: '/welcome',
@@ -348,6 +357,7 @@ describe('OidcCallbackView', () => {
await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ') await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810') await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="oidc-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click')
await flushPromises() await flushPromises()
@@ -355,6 +365,7 @@ describe('OidcCallbackView', () => {
email: 'new@example.com', email: 'new@example.com',
password: 'secret-123', password: 'secret-123',
verify_code: '246810', verify_code: '246810',
invitation_code: 'INVITE123',
adopt_display_name: true, adopt_display_name: true,
adopt_avatar: false adopt_avatar: false
}) })
@@ -362,12 +373,48 @@ describe('OidcCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome') expect(replace).toHaveBeenCalledWith('/welcome')
}) })
it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome'
})
apiClientPost.mockRejectedValue({
response: {
data: {
reason: 'EMAIL_EXISTS',
message: 'email already exists'
}
}
})
const wrapper = mount(OidcCallbackView, {
global: {
stubs: {
AuthLayout: { template: '<div><slot /></div>' },
Icon: true,
RouterLink: { template: '<a><slot /></a>' },
transition: false
}
}
})
await flushPromises()
await wrapper.get('[data-testid="oidc-create-account-email"]').setValue('existing@example.com')
await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click')
await flushPromises()
expect((wrapper.get('[data-testid="oidc-bind-login-email"]').element as HTMLInputElement).value).toBe(
'existing@example.com'
)
})
it('sends a verify code for pending oauth account creation', async () => { it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({ exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required', error: 'email_required',
redirect: '/welcome' redirect: '/welcome'
}) })
sendVerifyCode.mockResolvedValue({ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent', message: 'sent',
countdown: 60 countdown: 60
}) })
@@ -389,7 +436,7 @@ describe('OidcCallbackView', () => {
await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click') await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click')
await flushPromises() await flushPromises()
expect(sendVerifyCode).toHaveBeenCalledWith({ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'new@example.com' email: 'new@example.com'
}) })
}) })

View File

@@ -8,6 +8,8 @@ const {
login2FAMock, login2FAMock,
apiClientPostMock, apiClientPostMock,
sendVerifyCodeMock, sendVerifyCodeMock,
sendPendingOAuthVerifyCodeMock,
getPublicSettingsMock,
prepareOAuthBindAccessTokenCookieMock, prepareOAuthBindAccessTokenCookieMock,
getAuthTokenMock, getAuthTokenMock,
replaceMock, replaceMock,
@@ -24,6 +26,8 @@ const {
login2FAMock: vi.fn(), login2FAMock: vi.fn(),
apiClientPostMock: vi.fn(), apiClientPostMock: vi.fn(),
sendVerifyCodeMock: vi.fn(), sendVerifyCodeMock: vi.fn(),
sendPendingOAuthVerifyCodeMock: vi.fn(),
getPublicSettingsMock: vi.fn(),
prepareOAuthBindAccessTokenCookieMock: vi.fn(), prepareOAuthBindAccessTokenCookieMock: vi.fn(),
getAuthTokenMock: vi.fn(), getAuthTokenMock: vi.fn(),
replaceMock: vi.fn(), replaceMock: vi.fn(),
@@ -130,6 +134,8 @@ vi.mock('@/api/auth', async () => {
completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args), completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args),
login2FA: (...args: any[]) => login2FAMock(...args), login2FA: (...args: any[]) => login2FAMock(...args),
sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args), sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args),
sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args),
getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args),
prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args), prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args),
getAuthToken: (...args: any[]) => getAuthTokenMock(...args), getAuthToken: (...args: any[]) => getAuthTokenMock(...args),
} }
@@ -142,6 +148,8 @@ describe('WechatCallbackView', () => {
login2FAMock.mockReset() login2FAMock.mockReset()
apiClientPostMock.mockReset() apiClientPostMock.mockReset()
sendVerifyCodeMock.mockReset() sendVerifyCodeMock.mockReset()
sendPendingOAuthVerifyCodeMock.mockReset()
getPublicSettingsMock.mockReset()
replaceMock.mockReset() replaceMock.mockReset()
setTokenMock.mockReset() setTokenMock.mockReset()
showSuccessMock.mockReset() showSuccessMock.mockReset()
@@ -167,6 +175,11 @@ describe('WechatCallbackView', () => {
configurable: true, configurable: true,
value: 'Mozilla/5.0', value: 'Mozilla/5.0',
}) })
getPublicSettingsMock.mockResolvedValue({
invitation_code_enabled: false,
turnstile_enabled: false,
turnstile_site_key: '',
})
}) })
it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => { it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => {
@@ -478,6 +491,11 @@ describe('WechatCallbackView', () => {
}) })
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
getPublicSettingsMock.mockResolvedValue({
invitation_code_enabled: true,
turnstile_enabled: false,
turnstile_site_key: '',
})
exchangePendingOAuthCompletionMock.mockResolvedValue({ exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required', error: 'email_required',
redirect: '/welcome', redirect: '/welcome',
@@ -514,6 +532,7 @@ describe('WechatCallbackView', () => {
await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ') await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810') await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="wechat-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click')
await flushPromises() await flushPromises()
@@ -521,6 +540,7 @@ describe('WechatCallbackView', () => {
email: 'new@example.com', email: 'new@example.com',
password: 'secret-123', password: 'secret-123',
verify_code: '246810', verify_code: '246810',
invitation_code: 'INVITE123',
adopt_display_name: true, adopt_display_name: true,
adopt_avatar: false, adopt_avatar: false,
}) })
@@ -528,12 +548,48 @@ describe('WechatCallbackView', () => {
expect(replaceMock).toHaveBeenCalledWith('/welcome') expect(replaceMock).toHaveBeenCalledWith('/welcome')
}) })
it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
})
apiClientPostMock.mockRejectedValue({
response: {
data: {
reason: 'EMAIL_EXISTS',
message: 'email already exists',
},
},
})
const wrapper = mount(WechatCallbackView, {
global: {
stubs: {
AuthLayout: { template: '<div><slot /></div>' },
Icon: true,
RouterLink: { template: '<a><slot /></a>' },
transition: false,
},
},
})
await flushPromises()
await wrapper.get('[data-testid="wechat-create-account-email"]').setValue('existing@example.com')
await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click')
await flushPromises()
expect((wrapper.get('[data-testid="wechat-bind-login-email"]').element as HTMLInputElement).value).toBe(
'existing@example.com'
)
})
it('sends a verify code for pending oauth account creation', async () => { it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({ exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required', error: 'email_required',
redirect: '/welcome', redirect: '/welcome',
}) })
sendVerifyCodeMock.mockResolvedValue({ sendPendingOAuthVerifyCodeMock.mockResolvedValue({
message: 'sent', message: 'sent',
countdown: 60, countdown: 60,
}) })
@@ -555,7 +611,7 @@ describe('WechatCallbackView', () => {
await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click') await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click')
await flushPromises() await flushPromises()
expect(sendVerifyCodeMock).toHaveBeenCalledWith({ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
email: 'new@example.com', email: 'new@example.com',
}) })
}) })