From c229f33e9e827ab278be70f23a40bb01957569b5 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 10:26:22 +0800 Subject: [PATCH 01/31] fix(review): harden payment, oauth, and migration paths --- backend/ent/migrate/schema.go | 11 +- backend/ent/schema/auth_identity.go | 3 +- backend/ent/schema/payment_order.go | 4 +- backend/ent/schema/pending_auth_session.go | 1 + backend/ent/schema/user.go | 3 +- backend/internal/config/config.go | 4 +- .../internal/handler/auth_linuxdo_oauth.go | 34 ++- .../handler/auth_linuxdo_oauth_test.go | 22 ++ backend/internal/payment/wire.go | 16 +- backend/internal/payment/wire_test.go | 62 ++++++ backend/internal/server/routes/auth.go | 1 + .../internal/service/payment_fulfillment.go | 24 +- .../service/payment_fulfillment_test.go | 11 + backend/internal/service/payment_order.go | 29 ++- .../service/payment_order_lifecycle.go | 10 + .../service/payment_order_lifecycle_test.go | 91 ++++++++ .../service/payment_order_result_test.go | 42 +++- ...dd_payment_order_provider_key_snapshot.sql | 2 +- ...hat_dual_mode_and_auth_source_defaults.sql | 9 - ...rce_payment_orders_out_trade_no_unique.sql | 7 + ...tity_payment_migrations_regression_test.go | 37 ++++ .../api/__tests__/auth-oauth-adoption.spec.ts | 14 +- frontend/src/api/auth.ts | 28 +-- frontend/src/api/user.ts | 6 +- frontend/src/router/__tests__/guards.spec.ts | 18 +- .../src/router/__tests__/wechat-route.spec.ts | 9 + frontend/src/router/index.ts | 3 +- .../src/views/auth/WechatCallbackView.vue | 2 +- frontend/src/views/user/PaymentResultView.vue | 23 +- frontend/src/views/user/PaymentView.vue | 20 ++ .../user/__tests__/PaymentResultView.spec.ts | 22 ++ .../views/user/__tests__/PaymentView.spec.ts | 205 ++++++++++++++++++ .../views/user/__tests__/paymentUx.spec.ts | 10 + 33 files changed, 704 insertions(+), 79 deletions(-) create mode 100644 backend/internal/payment/wire_test.go create mode 100644 backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql create mode 100644 backend/migrations/auth_identity_payment_migrations_regression_test.go create mode 100644 frontend/src/views/user/__tests__/PaymentView.spec.ts diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 81f6a664..40b326a9 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -361,7 +361,7 @@ var ( Symbol: "auth_identities_users_auth_identities", Columns: []*schema.Column{AuthIdentitiesColumns[9]}, RefColumns: []*schema.Column{UsersColumns[0]}, - OnDelete: schema.NoAction, + OnDelete: schema.Cascade, }, }, Indexes: []*schema.Index{ @@ -405,7 +405,7 @@ var ( Symbol: "auth_identity_channels_auth_identities_channels", Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, - OnDelete: schema.NoAction, + OnDelete: schema.Cascade, }, }, Indexes: []*schema.Index{ @@ -595,7 +595,7 @@ var ( Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision", Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]}, - OnDelete: schema.NoAction, + OnDelete: schema.Cascade, }, }, Indexes: []*schema.Index{ @@ -692,8 +692,11 @@ var ( Indexes: []*schema.Index{ { Name: "paymentorder_out_trade_no", - Unique: false, + Unique: true, Columns: []*schema.Column{PaymentOrdersColumns[8]}, + Annotation: &entsql.IndexAnnotation{ + Where: "out_trade_no <> ''", + }, }, { Name: "paymentorder_user_id", diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go index e4b9ac90..0b1b56ab 100644 --- a/backend/ent/schema/auth_identity.go +++ b/backend/ent/schema/auth_identity.go @@ -79,7 +79,8 @@ func (AuthIdentity) Edges() []ent.Edge { Field("user_id"). Required(). Unique(), - edge.To("channels", AuthIdentityChannel.Type), + edge.To("channels", AuthIdentityChannel.Type). + Annotations(entsql.OnDelete(entsql.Cascade)), edge.To("adoption_decisions", IdentityAdoptionDecision.Type), } } diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go index 5815d032..d25d1e5e 100644 --- a/backend/ent/schema/payment_order.go +++ b/backend/ent/schema/payment_order.go @@ -185,7 +185,9 @@ func (PaymentOrder) Edges() []ent.Edge { func (PaymentOrder) Indexes() []ent.Index { return []ent.Index{ - index.Fields("out_trade_no"), + index.Fields("out_trade_no"). + Unique(). + Annotations(entsql.IndexWhere("out_trade_no <> ''")), index.Fields("user_id"), index.Fields("status"), index.Fields("expires_at"), diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go index 91341d49..7e95f085 100644 --- a/backend/ent/schema/pending_auth_session.go +++ b/backend/ent/schema/pending_auth_session.go @@ -119,6 +119,7 @@ func (PendingAuthSession) Edges() []ent.Edge { Field("target_user_id"). Unique(), edge.To("adoption_decision", IdentityAdoptionDecision.Type). + Annotations(entsql.OnDelete(entsql.Cascade)). Unique(), } } diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index bb58d9e3..f307bda8 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -115,7 +115,8 @@ func (User) Edges() []ent.Edge { edge.To("attribute_values", UserAttributeValue.Type), edge.To("promo_code_usages", PromoCodeUsage.Type), edge.To("payment_orders", PaymentOrder.Type), - edge.To("auth_identities", AuthIdentity.Type), + edge.To("auth_identities", AuthIdentity.Type). + Annotations(entsql.OnDelete(entsql.Cascade)), edge.To("pending_auth_sessions", PendingAuthSession.Type), } } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 44bc5c9f..f355a15d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1202,7 +1202,7 @@ func setDefaults() { viper.SetDefault("linuxdo_connect.redirect_url", "") viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") - viper.SetDefault("linuxdo_connect.use_pkce", false) + viper.SetDefault("linuxdo_connect.use_pkce", true) viper.SetDefault("linuxdo_connect.userinfo_email_path", "") viper.SetDefault("linuxdo_connect.userinfo_id_path", "") viper.SetDefault("linuxdo_connect.userinfo_username_path", "") @@ -1222,7 +1222,7 @@ func setDefaults() { viper.SetDefault("oidc_connect.redirect_url", "") viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback") viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post") - viper.SetDefault("oidc_connect.use_pkce", false) + viper.SetDefault("oidc_connect.use_pkce", true) viper.SetDefault("oidc_connect.validate_id_token", true) viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256") viper.SetDefault("oidc_connect.clock_skew_seconds", 120) diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index e0bee2f5..2bd44e78 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -937,7 +937,19 @@ func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) { Value: "", Path: oauthBindAccessTokenCookiePath, MaxAge: -1, - HttpOnly: false, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthBindAccessTokenCookieName, + Value: url.QueryEscape(strings.TrimSpace(token)), + Path: oauthBindAccessTokenCookiePath, + MaxAge: linuxDoOAuthCookieMaxAgeSec, + HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, }) @@ -1021,6 +1033,26 @@ func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (strin return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret()) } +func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) { + const bearerPrefix = "Bearer " + + authHeader := strings.TrimSpace(c.GetHeader("Authorization")) + if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) { + response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required")) + return + } + + token := strings.TrimSpace(authHeader[len(bearerPrefix):]) + if token == "" { + response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required")) + return + } + + setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c)) + c.Status(http.StatusNoContent) + c.Writer.WriteHeaderNow() +} + func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) { if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { return &subject.UserID, nil diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 0c760ee9..a3d87dfb 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -5,6 +5,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -226,6 +227,27 @@ func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) { require.Equal(t, -1, accessTokenCookie.MaxAge) } +func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) { + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{}) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil) + req.Header.Set("Authorization", "Bearer access-token-value") + c.Request = req + + handler.PrepareOAuthBindAccessTokenCookie(c) + + require.Equal(t, http.StatusNoContent, recorder.Code) + accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName) + require.NotNil(t, accessTokenCookie) + require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path) + require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge) + require.True(t, accessTokenCookie.HttpOnly) + require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value) +} + func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { diff --git a/backend/internal/payment/wire.go b/backend/internal/payment/wire.go index 9717465d..4b7f422d 100644 --- a/backend/internal/payment/wire.go +++ b/backend/internal/payment/wire.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "fmt" "log/slog" + "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" @@ -19,11 +20,22 @@ type EncryptionKey []byte // When the key is non-empty but invalid (bad hex or wrong length), an error is returned // to prevent startup with a misconfigured encryption key. func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) { - if cfg.Totp.EncryptionKey == "" { + if cfg == nil { + slog.Warn("payment encryption key not configured — encrypted payment config and resume signing will be unavailable") + return nil, nil + } + keyHex := strings.TrimSpace(cfg.Totp.EncryptionKey) + if keyHex == "" { slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable") return nil, nil } - key, err := hex.DecodeString(cfg.Totp.EncryptionKey) + // Reject auto-generated TOTP keys for payment signing. + // They change across restarts/instances and can silently break resume-token flows. + if !cfg.Totp.EncryptionKeyConfigured { + slog.Warn("payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens") + return nil, nil + } + key, err := hex.DecodeString(keyHex) if err != nil { return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err) } diff --git a/backend/internal/payment/wire_test.go b/backend/internal/payment/wire_test.go new file mode 100644 index 00000000..1b360f89 --- /dev/null +++ b/backend/internal/payment/wire_test.go @@ -0,0 +1,62 @@ +package payment + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func TestProvideEncryptionKeySkipsAutoGeneratedTotpKey(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Totp: config.TotpConfig{ + EncryptionKey: strings.Repeat("a", 64), + EncryptionKeyConfigured: false, + }, + } + + key, err := ProvideEncryptionKey(cfg) + if err != nil { + t.Fatalf("ProvideEncryptionKey returned error: %v", err) + } + if len(key) != 0 { + t.Fatalf("encryption key len = %d, want 0", len(key)) + } +} + +func TestProvideEncryptionKeyUsesConfiguredTotpKey(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Totp: config.TotpConfig{ + EncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + EncryptionKeyConfigured: true, + }, + } + + key, err := ProvideEncryptionKey(cfg) + if err != nil { + t.Fatalf("ProvideEncryptionKey returned error: %v", err) + } + if len(key) != 32 { + t.Fatalf("encryption key len = %d, want 32", len(key)) + } +} + +func TestProvideEncryptionKeyRejectsConfiguredInvalidLength(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Totp: config.TotpConfig{ + EncryptionKey: "abcd", + EncryptionKeyConfigured: true, + }, + } + + _, err := ProvideEncryptionKey(cfg) + if err == nil { + t.Fatal("expected error for invalid key length") + } +} diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index f1032eb5..b4b75795 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -164,6 +164,7 @@ func RegisterAuthRoutes( authenticated.GET("/auth/me", h.Auth.GetCurrentUser) // 撤销所有会话(需要认证) authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions) + authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie) authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) { query := c.Request.URL.Query() query.Set("intent", "bind_current_user") diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 904960ee..71f1eb2f 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -80,21 +80,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo }) return err } - // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). - // Also skip if paid is NaN/Inf (malformed provider data). - if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { - if math.Abs(paid-o.PayAmount) > amountToleranceCNY { - s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo}) - return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid) - } + if !isValidProviderAmount(paid) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", pk, map[string]any{ + "expected": o.PayAmount, + "paid": paid, + "tradeNo": tradeNo, + }) + return fmt.Errorf("invalid paid amount from provider: %v", paid) } - // Use order's expected amount when provider didn't report one - if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) { - paid = o.PayAmount + if math.Abs(paid-o.PayAmount) > amountToleranceCNY { + s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo}) + return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid) } return s.toPaid(ctx, o, tradeNo, paid, pk) } +func isValidProviderAmount(amount float64) bool { + return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0) +} + func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error { return validateProviderSnapshotMetadata(order, providerKey, metadata) } diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 6aed19f8..abdb59de 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -5,6 +5,7 @@ package service import ( "context" "errors" + "math" "testing" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -322,6 +323,16 @@ func TestParseLegacyPaymentOrderID(t *testing.T) { assert.False(t, ok) } +func TestIsValidProviderAmount(t *testing.T) { + t.Parallel() + + assert.True(t, isValidProviderAmount(0.01)) + assert.False(t, isValidProviderAmount(0)) + assert.False(t, isValidProviderAmount(-1)) + assert.False(t, isValidProviderAmount(math.NaN())) + assert.False(t, isValidProviderAmount(math.Inf(1))) +} + func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) { t.Parallel() diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 6554526e..3fdcecb5 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -139,6 +139,10 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq tm = defaultOrderTimeoutMin } exp := time.Now().Add(time.Duration(tm) * time.Minute) + outTradeNo, err := s.allocateOutTradeNo(ctx, tx) + if err != nil { + return nil, err + } providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req) selectedInstanceID := "" selectedProviderKey := "" @@ -155,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq SetPayAmount(payAmount). SetFeeRate(feeRate). SetRechargeCode(""). - SetOutTradeNo(generateOutTradeNo()). + SetOutTradeNo(outTradeNo). SetPaymentType(req.PaymentType). SetPaymentTradeNo(""). SetOrderType(req.OrderType). @@ -193,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq return order, nil } +func (s *PaymentService) allocateOutTradeNo(ctx context.Context, tx *dbent.Tx) (string, error) { + const maxAttempts = 5 + for attempt := 0; attempt < maxAttempts; attempt++ { + candidate := generateOutTradeNo() + exists, err := tx.PaymentOrder.Query().Where(paymentorder.OutTradeNo(candidate)).Exist(ctx) + if err != nil { + return "", fmt.Errorf("check out_trade_no uniqueness: %w", err) + } + if !exists { + return candidate, nil + } + } + return "", fmt.Errorf("generate unique out_trade_no: exhausted %d attempts", maxAttempts) +} + func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error { if max <= 0 { max = defaultMaxPendingOrders @@ -366,7 +385,10 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen } resumeToken := "" if resume := s.paymentResume(); resume != nil { - if resume.isSigningConfigured() { + if canonicalReturnURL != "" { + if err := resume.ensureSigningKey(); err != nil { + return nil, err + } resumeToken, err = resume.CreateToken(ResumeTokenClaims{ OrderID: order.ID, UserID: order.UserID, @@ -482,6 +504,9 @@ func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, r if err != nil { return nil, err } + if err := s.paymentResume().ensureSigningKey(); err != nil { + return nil, err + } authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base") if err != nil { diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index ccab7c11..ffb63066 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -150,6 +150,16 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s return "" } if resp.Status == payment.ProviderStatusPaid { + if !isValidProviderAmount(resp.Amount) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{ + "expected": o.PayAmount, + "paid": resp.Amount, + "tradeNo": resp.TradeNo, + "queryRef": queryRef, + }) + slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount) + return "" + } notificationTradeNo := o.PaymentTradeNo if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) { if _, updateErr := s.entClient.PaymentOrder.Update(). diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go index 39993a2f..cabdb445 100644 --- a/backend/internal/service/payment_order_lifecycle_test.go +++ b/backend/internal/service/payment_order_lifecycle_test.go @@ -234,6 +234,97 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) { require.Equal(t, user.ID, redeemRepo.useCalls[0].userID) } +func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid-zero-amount@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-zero-amount-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-ZERO-AMOUNT"). + SetOutTradeNo("sub2_checkpaid_zero_amount"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + resp: &payment.QueryOrderResponse{ + TradeNo: "upstream-trade-zero", + Status: payment.ProviderStatusPaid, + Amount: 0, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) + require.Equal(t, OrderStatusPending, got.Status) + require.Empty(t, got.PaymentTradeNo) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusPending, reloaded.Status) + require.Empty(t, reloaded.PaymentTradeNo) + + require.Equal(t, 0.0, userRepo.getByIDUser.Balance) + require.Empty(t, redeemRepo.useCalls) +} + func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) { ctx := context.Background() client := newPaymentOrderLifecycleTestClient(t) diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go index 16757323..23371cfd 100644 --- a/backend/internal/service/payment_order_result_test.go +++ b/backend/internal/service/payment_order_result_test.go @@ -159,6 +159,45 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin } } +func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testing.T) { + t.Parallel() + + svc := &PaymentService{ + configService: &PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }}, + // Intentionally missing payment resume signing key. + encryptionKey: nil, + }, + } + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) + } + if err == nil { + t.Fatal("expected error, got nil") + } + + appErr := infraerrors.FromError(err) + if appErr.Reason != "PAYMENT_RESUME_NOT_CONFIGURED" { + t.Fatalf("reason = %q, want %q", appErr.Reason, "PAYMENT_RESUME_NOT_CONFIGURED") + } +} + func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { svc := newWeChatPaymentOAuthTestService(map[string]string{ SettingKeyWeChatConnectEnabled: "true", @@ -189,7 +228,8 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService { return &PaymentService{ configService: &PaymentConfigService{ - settingRepo: &paymentConfigSettingRepoStub{values: values}, + settingRepo: &paymentConfigSettingRepoStub{values: values}, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), }, } } diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql index 7ec19ae3..d331b824 100644 --- a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql +++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql @@ -1,4 +1,4 @@ -ALTER TABLE payment_orders ADD COLUMN provider_key VARCHAR(30); +ALTER TABLE payment_orders ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30); UPDATE payment_orders SET provider_key = ( diff --git a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql index 6eef59e2..9b037984 100644 --- a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql +++ b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql @@ -21,12 +21,3 @@ VALUES ('auth_source_default_oidc_grant_on_signup', 'false'), ('auth_source_default_wechat_grant_on_signup', 'false') ON CONFLICT (key) DO NOTHING; - -UPDATE settings -SET value = 'false' -WHERE key IN ( - 'auth_source_default_email_grant_on_signup', - 'auth_source_default_linuxdo_grant_on_signup', - 'auth_source_default_oidc_grant_on_signup', - 'auth_source_default_wechat_grant_on_signup' -); diff --git a/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql new file mode 100644 index 00000000..4e256562 --- /dev/null +++ b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql @@ -0,0 +1,7 @@ +-- Replace the legacy non-unique index with a partial unique index. +-- Keep empty-string legacy rows compatible while enforcing uniqueness for real order IDs. +DROP INDEX IF EXISTS paymentorder_out_trade_no; + +CREATE UNIQUE INDEX IF NOT EXISTS paymentorder_out_trade_no + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go new file mode 100644 index 00000000..1c4a51fa --- /dev/null +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -0,0 +1,37 @@ +package migrations + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMigration112UsesIdempotentAddColumn(t *testing.T) { + content, err := FS.ReadFile("112_add_payment_order_provider_key_snapshot.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30)") + require.NotContains(t, sql, "ADD COLUMN provider_key VARCHAR(30);") +} + +func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T) { + content, err := FS.ReadFile("118_wechat_dual_mode_and_auth_source_defaults.sql") + require.NoError(t, err) + + sql := string(content) + require.NotContains(t, sql, "UPDATE settings") + require.NotContains(t, sql, "SET value = 'false'") + require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING")) +} + +func TestMigration119EnforcesOutTradeNoPartialUniqueIndex(t *testing.T) { + content, err := FS.ReadFile("119_enforce_payment_orders_out_trade_no_unique.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "DROP INDEX IF EXISTS paymentorder_out_trade_no") + require.Contains(t, sql, "CREATE UNIQUE INDEX IF NOT EXISTS paymentorder_out_trade_no") + require.Contains(t, sql, "WHERE out_trade_no <> ''") +} diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts index f95332fb..a484d7ed 100644 --- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -173,20 +173,12 @@ describe('oauth adoption auth api', () => { expect(hasPendingOAuthSuggestedProfile({})).toBe(false) }) - it('prepares an oauth bind access token cookie before redirect binding', async () => { + it('requests an HttpOnly oauth bind cookie before redirect binding', async () => { localStorage.setItem('auth_token', 'access-token-value') - const setCookie = vi.fn() - Object.defineProperty(document, 'cookie', { - configurable: true, - get: () => '', - set: setCookie - }) - const { prepareOAuthBindAccessTokenCookie } = await import('@/api/auth') - prepareOAuthBindAccessTokenCookie() + await prepareOAuthBindAccessTokenCookie() - expect(setCookie).toHaveBeenCalledTimes(1) - expect(setCookie.mock.calls[0]?.[0]).toContain('oauth_bind_access_token=access-token-value') + expect(post).toHaveBeenCalledWith('/auth/oauth/bind-token') }) }) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 9244489c..9621c26e 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -278,33 +278,11 @@ export function persistOAuthTokenContext(tokens: Partial): v } } -export function prepareOAuthBindAccessTokenCookie(): void { - if (typeof document === 'undefined' || typeof window === 'undefined') { +export async function prepareOAuthBindAccessTokenCookie(): Promise { + if (!getAuthToken()) { return } - - const token = getAuthToken() - if (!token) { - return - } - - const secure = window.location.protocol === 'https:' ? '; Secure' : '' - const path = resolveOAuthBindCookiePath() - document.cookie = - `oauth_bind_access_token=${encodeURIComponent(token)}; Path=${path}/auth/oauth; Max-Age=600; SameSite=Lax${secure}` -} - -function resolveOAuthBindCookiePath(): string { - const apiBase = ((import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1').replace(/\/$/, '') - - try { - return new URL(apiBase, window.location.origin).pathname.replace(/\/$/, '') || '/api/v1' - } catch { - if (apiBase.startsWith('/')) { - return apiBase - } - return '/api/v1' - } + await apiClient.post('/auth/oauth/bind-token') } /** diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index 32ef07e0..f6baf49d 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -153,10 +153,10 @@ export function buildOAuthBindingStartURL( return `${normalized}/auth/oauth/${provider}/start?${params.toString()}` } -export function startOAuthBinding( +export async function startOAuthBinding( provider: BindableOAuthProvider, options: BuildOAuthBindingStartURLOptions = {} -): void { +): Promise { if (typeof window === 'undefined') { return } @@ -164,7 +164,7 @@ export function startOAuthBinding( if (!startURL) { return } - prepareOAuthBindAccessTokenCookie() + await prepareOAuthBindAccessTokenCookie() window.location.href = startURL } diff --git a/frontend/src/router/__tests__/guards.spec.ts b/frontend/src/router/__tests__/guards.spec.ts index 11636139..bdf07b18 100644 --- a/frontend/src/router/__tests__/guards.spec.ts +++ b/frontend/src/router/__tests__/guards.spec.ts @@ -83,7 +83,8 @@ function simulateGuard( '/auth/callback', '/auth/linuxdo/callback', '/auth/oidc/callback', - '/auth/wechat/callback' + '/auth/wechat/callback', + '/auth/wechat/payment/callback', ] const pendingAuthPaths = ['/register', '/email-verify'] const isAllowed = @@ -131,7 +132,8 @@ function simulateGuard( '/auth/callback', '/auth/linuxdo/callback', '/auth/oidc/callback', - '/auth/wechat/callback' + '/auth/wechat/callback', + '/auth/wechat/payment/callback', ] const pendingAuthPaths = ['/register', '/email-verify'] const isAllowed = @@ -448,6 +450,18 @@ describe('路由守卫逻辑', () => { expect(redirect).toBeNull() }) + it('unauthenticated: WeChat payment callback route is allowed', () => { + const authState: MockAuthState = { + isAuthenticated: false, + isAdmin: false, + isSimpleMode: false, + backendModeEnabled: true, + hasPendingAuthSession: false, + } + const redirect = simulateGuard('/auth/wechat/payment/callback', { requiresAuth: false }, authState) + expect(redirect).toBeNull() + }) + it('unauthenticated: /register is allowed when a pending auth session exists', () => { const authState: MockAuthState = { isAuthenticated: false, diff --git a/frontend/src/router/__tests__/wechat-route.spec.ts b/frontend/src/router/__tests__/wechat-route.spec.ts index 84b20452..f85a732d 100644 --- a/frontend/src/router/__tests__/wechat-route.spec.ts +++ b/frontend/src/router/__tests__/wechat-route.spec.ts @@ -52,4 +52,13 @@ describe('router WeChat OAuth route', () => { expect(route?.meta.requiresAuth).toBe(false) expect(route?.meta.title).toBe('WeChat OAuth Callback') }) + + it('registers the WeChat payment callback route as a public route', async () => { + const { default: router } = await import('@/router') + const route = router.getRoutes().find((record) => record.name === 'WeChatPaymentOAuthCallback') + + expect(route?.path).toBe('/auth/wechat/payment/callback') + expect(route?.meta.requiresAuth).toBe(false) + expect(route?.meta.title).toBe('WeChat Payment Callback') + }) }) diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 1a73e8aa..b7fcf475 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -547,7 +547,8 @@ const BACKEND_MODE_CALLBACK_PATHS = [ '/auth/callback', '/auth/linuxdo/callback', '/auth/oidc/callback', - '/auth/wechat/callback' + '/auth/wechat/callback', + '/auth/wechat/payment/callback', ] const BACKEND_MODE_PENDING_AUTH_PATHS = ['/register', '/email-verify'] diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue index 2bcc1c3d..9a71f62b 100644 --- a/frontend/src/views/auth/WechatCallbackView.vue +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -613,7 +613,7 @@ async function handleBindCurrentAccount() { return } - prepareOAuthBindAccessTokenCookie() + await prepareOAuthBindAccessTokenCookie() window.location.href = startURL } diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index 57e81f40..1af34540 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -101,7 +101,11 @@ import { ref, computed, onBeforeUnmount, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useRoute, useRouter } from 'vue-router' import OrderStatusBadge from '@/components/payment/OrderStatusBadge.vue' -import { PAYMENT_RECOVERY_STORAGE_KEY, readPaymentRecoverySnapshot } from '@/components/payment/paymentFlow' +import { + PAYMENT_RECOVERY_STORAGE_KEY, + clearPaymentRecoverySnapshot, + readPaymentRecoverySnapshot, +} from '@/components/payment/paymentFlow' import { usePaymentStore } from '@/stores/payment' import { paymentAPI } from '@/api/payment' import type { PaymentOrder } from '@/types/payment' @@ -193,6 +197,18 @@ function clearStatusRefreshTimer(): void { } } +function clearRecoverySnapshot(): void { + if (typeof window === 'undefined') return + clearPaymentRecoverySnapshot(window.localStorage, PAYMENT_RECOVERY_STORAGE_KEY) +} + +function clearRecoverySnapshotForTerminalStatus(status: string | null | undefined): void { + if (!status) return + if (!isPendingStatus(status)) { + clearRecoverySnapshot() + } +} + function scheduleStatusRefresh(refreshOrder: (() => Promise) | null): void { clearStatusRefreshTimer() if (!refreshOrder || !isPending.value || refreshAttempts.value >= STATUS_REFRESH_MAX_ATTEMPTS) { @@ -204,6 +220,7 @@ function scheduleStatusRefresh(refreshOrder: (() => Promise const refreshedOrder = await refreshOrder() if (refreshedOrder) { order.value = refreshedOrder + clearRecoverySnapshotForTerminalStatus(refreshedOrder.status) } if (isPendingStatus(order.value?.status)) { @@ -285,6 +302,10 @@ onMounted(async () => { if (isPendingStatus(order.value?.status)) { scheduleStatusRefresh(refreshOrder) + } else if (order.value) { + clearRecoverySnapshotForTerminalStatus(order.value.status) + } else if (returnInfo.value) { + clearRecoverySnapshot() } loading.value = false }) diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 7d037917..10aa7019 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -391,6 +391,20 @@ function resetPayment() { removeRecoverySnapshot() } +async function redirectToPaymentResult(state: PaymentRecoverySnapshot): Promise { + const query: Record = {} + if (state.orderId > 0) { + query.order_id = String(state.orderId) + } + if (state.resumeToken) { + query.resume_token = state.resumeToken + } + await router.push({ + path: '/payment/result', + query, + }) +} + function onPaymentDone() { const wasSubscription = paymentState.value.orderType === 'subscription' resetPayment() @@ -684,8 +698,14 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n const errMsg = String(jsapiResult.err_msg || '').toLowerCase() if (errMsg.includes('cancel')) { appStore.showInfo(t('payment.qr.cancelled')) + resetPayment() } else if (errMsg && !errMsg.includes('ok')) { applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod) + resetPayment() + } else { + const resultState = { ...decision.paymentState } + resetPayment() + await redirectToPaymentResult(resultState) } return } diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts index 34ced07a..94ae6ef8 100644 --- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -60,6 +60,21 @@ const orderFactory = (status: string) => ({ refund_amount: 0, }) +const recoverySnapshotFactory = (resumeToken: string) => ({ + orderId: 42, + amount: 88, + qrCode: '', + expiresAt: '2099-01-01T00:10:00.000Z', + paymentType: 'alipay', + payUrl: 'https://pay.example.com/session/42', + clientSecret: '', + payAmount: 88, + orderType: 'balance', + paymentMode: 'popup', + resumeToken, + createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), +}) + describe('PaymentResultView', () => { beforeEach(() => { routeState.query = {} @@ -162,6 +177,7 @@ describe('PaymentResultView', () => { expect(wrapper.text()).toContain('payment.result.success') expect(wrapper.text()).toContain('103.00') expect(wrapper.text()).toContain('100.00') + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() }) it('refreshes a pending resume-token result until the order becomes paid', async () => { @@ -169,6 +185,10 @@ describe('PaymentResultView', () => { routeState.query = { resume_token: 'resume-77', } + window.localStorage.setItem( + PAYMENT_RECOVERY_STORAGE_KEY, + JSON.stringify(recoverySnapshotFactory('resume-77')), + ) resolveOrderPublicByResumeToken .mockResolvedValueOnce({ data: orderFactory('PENDING'), @@ -189,6 +209,7 @@ describe('PaymentResultView', () => { expect(resolveOrderPublicByResumeToken).toHaveBeenCalledTimes(1) expect(wrapper.text()).toContain('payment.result.processing') + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).not.toBeNull() await vi.advanceTimersByTimeAsync(2000) await flushPromises() @@ -196,6 +217,7 @@ describe('PaymentResultView', () => { expect(resolveOrderPublicByResumeToken).toHaveBeenCalledTimes(2) expect(wrapper.text()).toContain('payment.result.success') expect(wrapper.text()).not.toContain('payment.result.failed') + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() }) it('does not fall back to public out_trade_no verification when resume_token recovery fails', async () => { diff --git a/frontend/src/views/user/__tests__/PaymentView.spec.ts b/frontend/src/views/user/__tests__/PaymentView.spec.ts new file mode 100644 index 00000000..f60ea962 --- /dev/null +++ b/frontend/src/views/user/__tests__/PaymentView.spec.ts @@ -0,0 +1,205 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, shallowMount } from '@vue/test-utils' +import PaymentView from '../PaymentView.vue' +import { PAYMENT_RECOVERY_STORAGE_KEY } from '@/components/payment/paymentFlow' + +const routeState = vi.hoisted(() => ({ + path: '/purchase', + query: {} as Record, +})) + +const routerReplace = vi.hoisted(() => vi.fn()) +const routerPush = vi.hoisted(() => vi.fn()) +const routerResolve = vi.hoisted(() => vi.fn(() => ({ href: '/payment/stripe?mock=1' }))) +const createOrder = vi.hoisted(() => vi.fn()) +const refreshUser = vi.hoisted(() => vi.fn()) +const fetchActiveSubscriptions = vi.hoisted(() => vi.fn().mockResolvedValue(undefined)) +const showError = vi.hoisted(() => vi.fn()) +const showInfo = vi.hoisted(() => vi.fn()) +const getCheckoutInfo = vi.hoisted(() => vi.fn()) +const bridgeInvoke = vi.hoisted(() => vi.fn()) + +vi.mock('vue-router', async () => { + const actual = await vi.importActual('vue-router') + return { + ...actual, + useRoute: () => routeState, + useRouter: () => ({ + replace: routerReplace, + push: routerPush, + resolve: routerResolve, + }), + } +}) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key, + }), + } +}) + +vi.mock('@/stores/auth', () => ({ + useAuthStore: () => ({ + user: { + username: 'demo-user', + balance: 0, + }, + refreshUser, + }), +})) + +vi.mock('@/stores/payment', () => ({ + usePaymentStore: () => ({ + createOrder, + }), +})) + +vi.mock('@/stores/subscriptions', () => ({ + useSubscriptionStore: () => ({ + activeSubscriptions: [], + fetchActiveSubscriptions, + }), +})) + +vi.mock('@/stores', () => ({ + useAppStore: () => ({ + showError, + showInfo, + }), +})) + +vi.mock('@/api/payment', () => ({ + paymentAPI: { + getCheckoutInfo, + }, +})) + +vi.mock('@/utils/device', () => ({ + isMobileDevice: () => true, +})) + +function checkoutInfoFixture() { + return { + data: { + methods: { + wxpay: { + daily_limit: 0, + daily_used: 0, + daily_remaining: 0, + single_min: 0, + single_max: 0, + fee_rate: 0, + available: true, + }, + }, + global_min: 0, + global_max: 0, + plans: [], + balance_disabled: false, + balance_recharge_multiplier: 1, + recharge_fee_rate: 0, + help_text: '', + help_image_url: '', + stripe_publishable_key: '', + }, + } +} + +function jsapiOrderFixture(resumeToken: string) { + return { + order_id: 123, + amount: 88, + pay_amount: 88, + fee_rate: 0, + expires_at: '2099-01-01T00:10:00.000Z', + payment_type: 'wxpay', + result_type: 'jsapi_ready' as const, + resume_token: resumeToken, + jsapi: { + appId: 'wx123', + timeStamp: '1712345678', + nonceStr: 'nonce', + package: 'prepay_id=wx123', + signType: 'RSA', + paySign: 'signed', + }, + } +} + +describe('PaymentView WeChat JSAPI flow', () => { + beforeEach(() => { + routeState.path = '/purchase' + routeState.query = { + wechat_resume: '1', + wechat_resume_token: 'resume-token-123', + } + routerReplace.mockReset().mockResolvedValue(undefined) + routerPush.mockReset().mockResolvedValue(undefined) + routerResolve.mockClear() + createOrder.mockReset() + refreshUser.mockReset() + fetchActiveSubscriptions.mockReset().mockResolvedValue(undefined) + showError.mockReset() + showInfo.mockReset() + getCheckoutInfo.mockReset().mockResolvedValue(checkoutInfoFixture()) + bridgeInvoke.mockReset() + window.localStorage.clear() + ;(window as Window & { WeixinJSBridge?: { invoke: typeof bridgeInvoke } }).WeixinJSBridge = { + invoke: bridgeInvoke, + } + }) + + it('resets payment state and redirects to /payment/result after JSAPI reports success', async () => { + createOrder.mockResolvedValue(jsapiOrderFixture('resume-token-123')) + bridgeInvoke.mockImplementation((_action, _payload, callback) => { + callback({ err_msg: 'get_brand_wcpay_request:ok' }) + }) + + shallowMount(PaymentView, { + global: { + stubs: { + Teleport: true, + Transition: false, + }, + }, + }) + await flushPromises() + await flushPromises() + + expect(routerReplace).toHaveBeenCalledWith({ path: '/purchase', query: {} }) + expect(routerPush).toHaveBeenCalledWith({ + path: '/payment/result', + query: { + order_id: '123', + resume_token: 'resume-token-123', + }, + }) + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() + }) + + it('resets payment state when JSAPI reports cancellation', async () => { + createOrder.mockResolvedValue(jsapiOrderFixture('resume-token-cancel')) + bridgeInvoke.mockImplementation((_action, _payload, callback) => { + callback({ err_msg: 'get_brand_wcpay_request:cancel' }) + }) + + shallowMount(PaymentView, { + global: { + stubs: { + Teleport: true, + Transition: false, + }, + }, + }) + await flushPromises() + await flushPromises() + + expect(showInfo).toHaveBeenCalledWith('payment.qr.cancelled') + expect(routerPush).not.toHaveBeenCalled() + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() + }) +}) diff --git a/frontend/src/views/user/__tests__/paymentUx.spec.ts b/frontend/src/views/user/__tests__/paymentUx.spec.ts index c2a4ac59..8d73d1fa 100644 --- a/frontend/src/views/user/__tests__/paymentUx.spec.ts +++ b/frontend/src/views/user/__tests__/paymentUx.spec.ts @@ -28,6 +28,16 @@ describe('describePaymentScenarioError', () => { }) }) + it('maps WeChat H5 authorization errors when provider aliases use wxpay_direct', () => { + expect(describePaymentScenarioError( + { reason: 'WECHAT_H5_NOT_AUTHORIZED' }, + { paymentMethod: 'wxpay_direct', isMobile: true, isWechatBrowser: false }, + )).toEqual({ + messageKey: 'payment.errors.wechatH5NotAuthorized', + hintKey: 'payment.errors.wechatOpenInWeChatHint', + }) + }) + it('maps missing WeixinJSBridge to a JSAPI-specific prompt', () => { expect(describePaymentScenarioError( new Error('WeixinJSBridge is unavailable'), From dd314c41e3e23dbe598dd1e0f83811c5d95586e4 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 11:17:23 +0800 Subject: [PATCH 02/31] fix(payment): restore public resume and result flows --- backend/internal/handler/payment_handler.go | 84 ++++++---- .../handler/payment_handler_resume_test.go | 152 +++++++++++++++++- backend/internal/server/routes/payment.go | 6 +- backend/internal/service/payment_order.go | 9 +- .../service/payment_resume_service.go | 27 +++- .../service/payment_resume_service_test.go | 55 ++++++- frontend/src/api/__tests__/payment.spec.ts | 8 +- frontend/src/api/payment.ts | 5 + .../payment/__tests__/paymentFlow.spec.ts | 29 ++++ .../src/components/payment/paymentFlow.ts | 6 +- frontend/src/views/user/PaymentResultView.vue | 44 +++-- frontend/src/views/user/PaymentView.vue | 13 +- .../user/__tests__/PaymentResultView.spec.ts | 40 ++--- .../views/user/__tests__/PaymentView.spec.ts | 37 +++++ .../src/views/user/paymentWechatResume.ts | 10 +- 15 files changed, 435 insertions(+), 90 deletions(-) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 16b25355..09580442 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -2,9 +2,9 @@ package handler import ( "fmt" - "net/http" "strconv" "strings" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/payment" @@ -454,29 +454,65 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) { // PublicOrderResult is the limited order info returned by the public verify endpoint. // No user details are exposed — only payment status information. type PublicOrderResult struct { - ID int64 `json:"id"` - OutTradeNo string `json:"out_trade_no"` - Amount float64 `json:"amount"` - PayAmount float64 `json:"pay_amount"` - PaymentType string `json:"payment_type"` - OrderType string `json:"order_type"` - Status string `json:"status"` + ID int64 `json:"id"` + OutTradeNo string `json:"out_trade_no"` + Amount float64 `json:"amount"` + PayAmount float64 `json:"pay_amount"` + FeeRate float64 `json:"fee_rate"` + PaymentType string `json:"payment_type"` + OrderType string `json:"order_type"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + PaidAt *time.Time `json:"paid_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + RefundAmount float64 `json:"refund_amount"` + RefundReason *string `json:"refund_reason,omitempty"` + RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"` + RefundRequestedBy *string `json:"refund_requested_by,omitempty"` + RefundRequestReason *string `json:"refund_request_reason,omitempty"` + PlanID *int64 `json:"plan_id,omitempty"` } -var errPaymentPublicOrderVerifyRemoved = infraerrors.New( - http.StatusGone, - "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", - "public payment order verification by out_trade_no has been removed; use resume_token recovery instead", -).WithMetadata(map[string]string{ - "replacement_endpoint": "/api/v1/payment/public/orders/resolve", - "replacement_field": "resume_token", -}) +func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult { + return PublicOrderResult{ + ID: order.ID, + OutTradeNo: order.OutTradeNo, + Amount: order.Amount, + PayAmount: order.PayAmount, + FeeRate: order.FeeRate, + PaymentType: order.PaymentType, + OrderType: order.OrderType, + Status: order.Status, + CreatedAt: order.CreatedAt, + ExpiresAt: order.ExpiresAt, + PaidAt: order.PaidAt, + CompletedAt: order.CompletedAt, + RefundAmount: order.RefundAmount, + RefundReason: order.RefundReason, + RefundRequestedAt: order.RefundRequestedAt, + RefundRequestedBy: order.RefundRequestedBy, + RefundRequestReason: order.RefundRequestReason, + PlanID: order.PlanID, + } +} -// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous -// out_trade_no lookup endpoint and always returns HTTP 410 Gone. +// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as +// a compatibility path for older result pages and staggered deploys. // POST /api/v1/payment/public/orders/verify func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) { - response.ErrorFrom(c, errPaymentPublicOrderVerifyRemoved) + var req VerifyOrderRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, buildPublicOrderResult(order)) } // ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token. @@ -493,15 +529,7 @@ func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, PublicOrderResult{ - ID: order.ID, - OutTradeNo: order.OutTradeNo, - Amount: order.Amount, - PayAmount: order.PayAmount, - PaymentType: order.PaymentType, - OrderType: order.OrderType, - Status: order.Status, - }) + response.Success(c, buildPublicOrderResult(order)) } // requireAuth extracts the authenticated subject from the context. diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go index 28da15d9..5a2ecb46 100644 --- a/backend/internal/handler/payment_handler_resume_test.go +++ b/backend/internal/handler/payment_handler_resume_test.go @@ -4,16 +4,17 @@ package handler import ( "bytes" + "context" "database/sql" "encoding/json" "net/http" "net/http/httptest" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/internal/payment" - "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -74,7 +75,7 @@ func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) } } -func TestVerifyOrderPublicReturnsGone(t *testing.T) { +func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) @@ -90,6 +91,32 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) { client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) t.Cleanup(func() { _ = client.Close() }) + user, err := client.User.Create(). + SetEmail("public-verify@example.com"). + SetPasswordHash("hash"). + SetUsername("public-verify-user"). + Save(context.Background()) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(90.64). + SetFeeRate(0.03). + SetRechargeCode("PUBLIC-VERIFY"). + SetOutTradeNo("legacy-order-no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-verify"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(service.OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(context.Background()) + require.NoError(t, err) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) @@ -104,11 +131,122 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) { h.VerifyOrderPublic(ctx) - require.Equal(t, http.StatusGone, recorder.Code) + require.Equal(t, http.StatusOK, recorder.Code) - var resp response.Response + var resp struct { + Code int `json:"code"` + Data struct { + ID int64 `json:"id"` + OutTradeNo string `json:"out_trade_no"` + Amount float64 `json:"amount"` + PayAmount float64 `json:"pay_amount"` + FeeRate float64 `json:"fee_rate"` + PaymentType string `json:"payment_type"` + OrderType string `json:"order_type"` + Status string `json:"status"` + RefundAmount float64 `json:"refund_amount"` + CreatedAt string `json:"created_at"` + ExpiresAt string `json:"expires_at"` + } `json:"data"` + } require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) - require.Equal(t, http.StatusGone, resp.Code) - require.Equal(t, "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", resp.Reason) - require.Contains(t, resp.Message, "removed") + require.Equal(t, 0, resp.Code) + require.Equal(t, order.ID, resp.Data.ID) + require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo) + require.Equal(t, 90.64, resp.Data.PayAmount) + require.Equal(t, 0.03, resp.Data.FeeRate) + require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType) + require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType) + require.Equal(t, service.OrderStatusPending, resp.Data.Status) + require.Equal(t, 0.0, resp.Data.RefundAmount) + require.NotEmpty(t, resp.Data.CreatedAt) + require.NotEmpty(t, resp.Data.ExpiresAt) +} + +func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + + db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + user, err := client.User.Create(). + SetEmail("public-resolve@example.com"). + SetPasswordHash("hash"). + SetUsername("public-resolve-user"). + Save(context.Background()) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(100). + SetPayAmount(103). + SetFeeRate(0.03). + SetRechargeCode("PUBLIC-RESOLVE"). + SetOutTradeNo("resolve-order-no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-resolve"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(service.OrderStatusPaid). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(context.Background()) + require.NoError(t, err) + + resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/resolve", + bytes.NewBufferString(`{"resume_token":"`+token+`"}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.ResolveOrderPublicByResumeToken(ctx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, float64(order.ID), resp.Data["id"]) + require.Equal(t, "resolve-order-no", resp.Data["out_trade_no"]) + require.Equal(t, 100.0, resp.Data["amount"]) + require.Equal(t, 103.0, resp.Data["pay_amount"]) + require.Equal(t, 0.03, resp.Data["fee_rate"]) + require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"]) + require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"]) + require.Equal(t, service.OrderStatusPaid, resp.Data["status"]) + require.Contains(t, resp.Data, "created_at") + require.Contains(t, resp.Data, "expires_at") + require.Contains(t, resp.Data, "refund_amount") } diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index ec340d94..e4828ead 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -44,9 +44,9 @@ func RegisterPaymentRoutes( } // --- Public payment endpoints (no auth) --- - // Signed resume-token recovery is the supported public lookup path. - // The legacy anonymous out_trade_no verify endpoint is kept only as a - // compatibility shim that returns HTTP 410 Gone. + // Signed resume-token recovery is the preferred public lookup path. + // The legacy anonymous out_trade_no verify endpoint remains available as a + // persisted-state compatibility path for staggered upgrades. public := v1.Group("/payment/public") { public.POST("/orders/verify", paymentHandler.VerifyOrderPublic) diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 3fdcecb5..15d4509d 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -379,16 +379,13 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen } subject := s.buildPaymentSubject(plan, limitAmount, cfg) outTradeNo := order.OutTradeNo - canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost) + canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL) if err != nil { return nil, err } resumeToken := "" if resume := s.paymentResume(); resume != nil { - if canonicalReturnURL != "" { - if err := resume.ensureSigningKey(); err != nil { - return nil, err - } + if canonicalReturnURL != "" && resume.isSigningConfigured() { resumeToken, err = resume.CreateToken(ResumeTokenClaims{ OrderID: order.ID, UserID: order.UserID, @@ -402,7 +399,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen } } } - providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken) + providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, outTradeNo, resumeToken) if err != nil { return nil, err } diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go index 6e8acccb..438aa59f 100644 --- a/backend/internal/service/payment_resume_service.go +++ b/backend/internal/service/payment_resume_service.go @@ -209,7 +209,7 @@ func visibleMethodSourceSettingKey(method string) string { } } -func CanonicalizeReturnURL(raw string, srcHost string) (string, error) { +func CanonicalizeReturnURL(raw string, srcHost string, srcURL string) (string, error) { raw = strings.TrimSpace(raw) if raw == "" { return "", nil @@ -228,13 +228,29 @@ func CanonicalizeReturnURL(raw string, srcHost string) (string, error) { if parsed.Path != paymentResultReturnPath { return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page") } - if !sameOriginHost(parsed.Host, srcHost) { - return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site") + if !allowedReturnURLHost(parsed.Host, srcHost, srcURL) { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site or browser origin") } return parsed.String(), nil } -func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) { +func allowedReturnURLHost(returnURLHost string, requestHost string, refererURL string) bool { + if sameOriginHost(returnURLHost, requestHost) { + return true + } + + refererURL = strings.TrimSpace(refererURL) + if refererURL == "" { + return false + } + parsedReferer, err := url.Parse(refererURL) + if err != nil || parsedReferer.Host == "" { + return false + } + return sameOriginHost(returnURLHost, parsedReferer.Host) +} + +func buildPaymentReturnURL(base string, orderID int64, outTradeNo string, resumeToken string) (string, error) { canonical := strings.TrimSpace(base) if canonical == "" { return "", nil @@ -253,6 +269,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri if orderID > 0 { query.Set("order_id", strconv.FormatInt(orderID, 10)) } + if strings.TrimSpace(outTradeNo) != "" { + query.Set("out_trade_no", strings.TrimSpace(outTradeNo)) + } if strings.TrimSpace(resumeToken) != "" { query.Set("resume_token", strings.TrimSpace(resumeToken)) } diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index 78b6bba3..ffa55e69 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -64,7 +64,7 @@ func TestNormalizePaymentSource(t *testing.T) { func TestCanonicalizeReturnURL(t *testing.T) { t.Parallel() - got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com") + got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com", "") if err != nil { t.Fatalf("CanonicalizeReturnURL returned error: %v", err) } @@ -76,7 +76,7 @@ func TestCanonicalizeReturnURL(t *testing.T) { func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) { t.Parallel() - if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil { + if _, err := CanonicalizeReturnURL("/payment/result", "example.com", ""); err == nil { t.Fatal("CanonicalizeReturnURL should reject relative URLs") } } @@ -84,15 +84,31 @@ func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) { func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) { t.Parallel() - if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil { + if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com", ""); err == nil { t.Fatal("CanonicalizeReturnURL should reject external hosts") } } +func TestCanonicalizeReturnURLAllowsConfiguredFrontendHost(t *testing.T) { + t.Parallel() + + got, err := CanonicalizeReturnURL( + "https://app.example.com/payment/result?from=checkout", + "api.example.com", + "https://app.example.com/purchase", + ) + if err != nil { + t.Fatalf("CanonicalizeReturnURL returned error: %v", err) + } + if got != "https://app.example.com/payment/result?from=checkout" { + t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://app.example.com/payment/result?from=checkout") + } +} + func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) { t.Parallel() - if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil { + if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com", ""); err == nil { t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths") } } @@ -100,7 +116,7 @@ func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) { func TestBuildPaymentReturnURL(t *testing.T) { t.Parallel() - got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "resume-token") + got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "sub2_42", "resume-token") if err != nil { t.Fatalf("buildPaymentReturnURL returned error: %v", err) } @@ -119,6 +135,9 @@ func TestBuildPaymentReturnURL(t *testing.T) { if query.Get("order_id") != strconv.FormatInt(42, 10) { t.Fatalf("order_id = %q", query.Get("order_id")) } + if query.Get("out_trade_no") != "sub2_42" { + t.Fatalf("out_trade_no = %q", query.Get("out_trade_no")) + } if query.Get("resume_token") != "resume-token" { t.Fatalf("resume_token = %q", query.Get("resume_token")) } @@ -127,10 +146,34 @@ func TestBuildPaymentReturnURL(t *testing.T) { } } +func TestBuildPaymentReturnURLWithoutResumeTokenStillIncludesOutTradeNo(t *testing.T) { + t.Parallel() + + got, err := buildPaymentReturnURL("https://example.com/payment/result", 42, "sub2_42", "") + if err != nil { + t.Fatalf("buildPaymentReturnURL returned error: %v", err) + } + + parsed, err := url.Parse(got) + if err != nil { + t.Fatalf("url.Parse returned error: %v", err) + } + query := parsed.Query() + if query.Get("order_id") != "42" { + t.Fatalf("order_id = %q", query.Get("order_id")) + } + if query.Get("out_trade_no") != "sub2_42" { + t.Fatalf("out_trade_no = %q", query.Get("out_trade_no")) + } + if query.Get("resume_token") != "" { + t.Fatalf("resume_token = %q, want empty", query.Get("resume_token")) + } +} + func TestBuildPaymentReturnURLEmptyBase(t *testing.T) { t.Parallel() - got, err := buildPaymentReturnURL("", 42, "resume-token") + got, err := buildPaymentReturnURL("", 42, "sub2_42", "resume-token") if err != nil { t.Fatalf("buildPaymentReturnURL returned error: %v", err) } diff --git a/frontend/src/api/__tests__/payment.spec.ts b/frontend/src/api/__tests__/payment.spec.ts index 3006484e..e38fba57 100644 --- a/frontend/src/api/__tests__/payment.spec.ts +++ b/frontend/src/api/__tests__/payment.spec.ts @@ -22,8 +22,12 @@ describe('payment api', () => { post.mockResolvedValue({ data: {} }) }) - it('does not expose anonymous public out_trade_no verification', () => { - expect(Object.prototype.hasOwnProperty.call(paymentAPI, 'verifyOrderPublic')).toBe(false) + it('keeps legacy public out_trade_no verification for upgrade compatibility', async () => { + await paymentAPI.verifyOrderPublic('legacy-order-no') + + expect(post).toHaveBeenCalledWith('/payment/public/orders/verify', { + out_trade_no: 'legacy-order-no', + }) }) it('keeps signed public resume-token resolve endpoint', async () => { diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts index e866e184..92b0ec90 100644 --- a/frontend/src/api/payment.ts +++ b/frontend/src/api/payment.ts @@ -67,6 +67,11 @@ export const paymentAPI = { return apiClient.post('/payment/orders/verify', { out_trade_no: outTradeNo }) }, + /** Legacy-compatible public order lookup by out_trade_no */ + verifyOrderPublic(outTradeNo: string) { + return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo }) + }, + /** Resolve an order from a signed resume token without auth */ resolveOrderPublicByResumeToken(resumeToken: string) { return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken }) diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts index 7f4d6186..48c77dfb 100644 --- a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts +++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts @@ -73,6 +73,7 @@ describe('decidePaymentLaunch', () => { expect(decision.paymentState.paymentType).toBe('alipay') expect(decision.stripeMethod).toBe('alipay') expect(decision.recovery.resumeToken).toBe('resume-1') + expect(decision.recovery.outTradeNo).toBe('') }) it('uses Stripe route flow for mobile WeChat client secret', () => { @@ -94,6 +95,7 @@ describe('decidePaymentLaunch', () => { pay_url: 'https://pay.example.com/session/abc', payment_mode: 'popup', resume_token: 'resume-2', + out_trade_no: 'sub2_abc', }), { visibleMethod: 'wxpay', orderType: 'balance', @@ -103,6 +105,7 @@ describe('decidePaymentLaunch', () => { expect(decision.kind).toBe('redirect_waiting') expect(decision.paymentState.payUrl).toBe('https://pay.example.com/session/abc') expect(decision.recovery.paymentMode).toBe('popup') + expect(decision.recovery.outTradeNo).toBe('sub2_abc') expect(decision.recovery.resumeToken).toBe('resume-2') }) @@ -225,6 +228,7 @@ describe('readPaymentRecoverySnapshot', () => { expiresAt: '2099-01-01T00:10:00.000Z', paymentType: 'alipay', payUrl: 'https://pay.example.com/session/33', + outTradeNo: 'sub2_33', clientSecret: '', payAmount: 18, orderType: 'balance', @@ -249,6 +253,7 @@ describe('readPaymentRecoverySnapshot', () => { expiresAt: '2024-01-01T00:10:00.000Z', paymentType: 'wxpay', payUrl: 'https://pay.example.com/session/55', + outTradeNo: 'sub2_55', clientSecret: '', payAmount: 18, orderType: 'balance', @@ -264,10 +269,34 @@ describe('readPaymentRecoverySnapshot', () => { expect(readPaymentRecoverySnapshot(JSON.stringify({ ...expiredSnapshot, + outTradeNo: 'sub2_55', expiresAt: '2099-01-01T00:10:00.000Z', }), { now: Date.UTC(2099, 0, 1, 0, 1, 0), resumeToken: 'other-token', })).toBeNull() }) + + it('keeps backward compatibility with snapshots written before outTradeNo existed', () => { + const restored = readPaymentRecoverySnapshot(JSON.stringify({ + orderId: 44, + amount: 18, + qrCode: '', + expiresAt: '2099-01-01T00:10:00.000Z', + paymentType: 'alipay', + payUrl: 'https://pay.example.com/session/44', + clientSecret: '', + payAmount: 18, + orderType: 'balance', + paymentMode: 'popup', + resumeToken: 'resume-44', + createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), + }), { + now: Date.UTC(2099, 0, 1, 0, 1, 0), + resumeToken: 'resume-44', + }) + + expect(restored?.orderId).toBe(44) + expect(restored?.outTradeNo).toBe('') + }) }) diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts index 7fbc1435..05f36ed0 100644 --- a/frontend/src/components/payment/paymentFlow.ts +++ b/frontend/src/components/payment/paymentFlow.ts @@ -34,6 +34,7 @@ export interface PaymentRecoverySnapshot { expiresAt: string paymentType: string payUrl: string + outTradeNo: string clientSecret: string payAmount: number orderType: OrderType | '' @@ -132,6 +133,7 @@ export function decidePaymentLaunch( expiresAt: result.expires_at || '', paymentType: visibleMethod, payUrl: result.pay_url || '', + outTradeNo: result.out_trade_no || '', clientSecret: result.client_secret || '', payAmount: result.pay_amount, orderType: context.orderType, @@ -227,6 +229,7 @@ export function readPaymentRecoverySnapshot( || typeof parsed.expiresAt !== 'string' || typeof parsed.paymentType !== 'string' || typeof parsed.payUrl !== 'string' + || (parsed.outTradeNo != null && typeof parsed.outTradeNo !== 'string') || typeof parsed.clientSecret !== 'string' || typeof parsed.payAmount !== 'number' || typeof parsed.paymentMode !== 'string' @@ -241,7 +244,7 @@ export function readPaymentRecoverySnapshot( if (Number.isFinite(expiresAt) && expiresAt <= now) { return null } - if (options.resumeToken && parsed.resumeToken && parsed.resumeToken !== options.resumeToken) { + if (options.resumeToken && parsed.resumeToken !== options.resumeToken) { return null } @@ -252,6 +255,7 @@ export function readPaymentRecoverySnapshot( expiresAt: parsed.expiresAt, paymentType: parsed.paymentType, payUrl: parsed.payUrl, + outTradeNo: parsed.outTradeNo || '', clientSecret: parsed.clientSecret, payAmount: parsed.payAmount, orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance', diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index 1af34540..cbebaa83 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -190,6 +190,15 @@ async function resolveOrderFromResumeToken(resumeToken: string): Promise { + try { + const result = await paymentAPI.verifyOrderPublic(outTradeNo) + return result.data + } catch (_err: unknown) { + return null + } +} + function clearStatusRefreshTimer(): void { if (statusRefreshTimer !== null) { clearTimeout(statusRefreshTimer) @@ -234,24 +243,19 @@ onMounted(async () => { ? route.query.resume_token : '' const routeOrderId = Number(route.query.order_id) || 0 - const outTradeNo = String(route.query.out_trade_no || '') + let outTradeNo = String(route.query.out_trade_no || '') let orderId = 0 - if (resumeToken && typeof window !== 'undefined') { + if (typeof window !== 'undefined') { const restored = readPaymentRecoverySnapshot( window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY), - { resumeToken }, + resumeToken ? { resumeToken } : {}, ) if (restored?.orderId) { orderId = restored.orderId } - } - - if (!order.value && resumeToken && orderId) { - try { - order.value = await paymentStore.pollOrderStatus(orderId) - } catch (_err: unknown) { - // Fall through to signed resume-token recovery below. + if (!outTradeNo && restored?.outTradeNo) { + outTradeNo = restored.outTradeNo } } @@ -269,6 +273,20 @@ onMounted(async () => { orderId = routeOrderId } + const hasLegacyFallbackContext = typeof route.query.trade_status === 'string' + && route.query.trade_status.trim() !== '' + const shouldUsePublicOutTradeNo = !resumeToken && outTradeNo !== '' && (hasLegacyFallbackContext || routeOrderId > 0 || orderId > 0) + + if (!order.value && shouldUsePublicOutTradeNo) { + const legacyOrder = await resolveOrderFromOutTradeNo(outTradeNo) + if (legacyOrder) { + order.value = legacyOrder + if (!orderId) { + orderId = legacyOrder.id + } + } + } + if (!order.value && !resumeToken && orderId) { try { order.value = await paymentStore.pollOrderStatus(orderId) @@ -277,8 +295,6 @@ onMounted(async () => { } } - const hasLegacyFallbackContext = typeof route.query.trade_status === 'string' - && route.query.trade_status.trim() !== '' if (!order.value && !resumeToken && !orderId && outTradeNo && hasLegacyFallbackContext) { returnInfo.value = { outTradeNo, @@ -293,6 +309,10 @@ onMounted(async () => { return await resolveOrderFromResumeToken(resumeToken) } + if (shouldUsePublicOutTradeNo) { + return await resolveOrderFromOutTradeNo(outTradeNo) + } + if (orderId) { return await paymentStore.pollOrderStatus(orderId) } diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 10aa7019..1577039e 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -276,7 +276,7 @@ import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue' import Icon from '@/components/icons/Icon.vue' import type { PaymentMethodOption } from '@/components/payment/PaymentMethodSelector.vue' import { buildPaymentErrorToastMessage, describePaymentScenarioError } from './paymentUx' -import { parseWechatResumeRoute, stripWechatResumeQuery } from './paymentWechatResume' +import { hasWechatResumeQuery, parseWechatResumeRoute, stripWechatResumeQuery } from './paymentWechatResume' const { t } = useI18n() const route = useRoute() @@ -329,6 +329,7 @@ function emptyPaymentState(): PaymentRecoverySnapshot { expiresAt: '', paymentType: '', payUrl: '', + outTradeNo: '', clientSecret: '', payAmount: 0, orderType: '', @@ -396,6 +397,9 @@ async function redirectToPaymentResult(state: PaymentRecoverySnapshot): Promise< if (state.orderId > 0) { query.order_id = String(state.orderId) } + if (state.outTradeNo) { + query.out_trade_no = state.outTradeNo + } if (state.resumeToken) { query.resume_token = state.resumeToken } @@ -809,9 +813,14 @@ onMounted(async () => { selectedMethod.value = sorted[0] } if (typeof window !== 'undefined') { + if (hasWechatResumeQuery(route.query)) { + removeRecoverySnapshot() + } const routeResumeToken = typeof route.query.resume_token === 'string' ? route.query.resume_token - : undefined + : typeof route.query.wechat_resume_token === 'string' + ? route.query.wechat_resume_token + : undefined const restored = readPaymentRecoverySnapshot( window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY), { resumeToken: routeResumeToken }, diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts index 94ae6ef8..91741963 100644 --- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -7,7 +7,7 @@ const routeState = vi.hoisted(() => ({ const routerPush = vi.hoisted(() => vi.fn()) const pollOrderStatus = vi.hoisted(() => vi.fn()) -const verifyOrder = vi.hoisted(() => vi.fn()) +const verifyOrderPublic = vi.hoisted(() => vi.fn()) const resolveOrderPublicByResumeToken = vi.hoisted(() => vi.fn()) vi.mock('vue-router', async () => { @@ -37,7 +37,7 @@ vi.mock('@/stores/payment', () => ({ vi.mock('@/api/payment', () => ({ paymentAPI: { - verifyOrder, + verifyOrderPublic, resolveOrderPublicByResumeToken, }, })) @@ -67,6 +67,7 @@ const recoverySnapshotFactory = (resumeToken: string) => ({ expiresAt: '2099-01-01T00:10:00.000Z', paymentType: 'alipay', payUrl: 'https://pay.example.com/session/42', + outTradeNo: 'sub2_20260420abcd1234', clientSecret: '', payAmount: 88, orderType: 'balance', @@ -80,7 +81,7 @@ describe('PaymentResultView', () => { routeState.query = {} routerPush.mockReset() pollOrderStatus.mockReset() - verifyOrder.mockReset() + verifyOrderPublic.mockReset() resolveOrderPublicByResumeToken.mockReset() window.localStorage.clear() }) @@ -102,6 +103,7 @@ describe('PaymentResultView', () => { expiresAt: '2099-01-01T00:10:00.000Z', paymentType: 'alipay', payUrl: 'https://pay.example.com/session/42', + outTradeNo: 'sub2_20260420abcd1234', clientSecret: '', payAmount: 88, orderType: 'balance', @@ -109,7 +111,9 @@ describe('PaymentResultView', () => { resumeToken: 'resume-42', createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), })) - pollOrderStatus.mockResolvedValue(orderFactory('PENDING')) + resolveOrderPublicByResumeToken.mockResolvedValue({ + data: orderFactory('PENDING'), + }) const wrapper = mount(PaymentResultView, { global: { @@ -121,7 +125,8 @@ describe('PaymentResultView', () => { await flushPromises() - expect(pollOrderStatus).toHaveBeenCalledWith(42) + expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-42') + expect(pollOrderStatus).not.toHaveBeenCalled() expect(wrapper.text()).toContain('payment.result.processing') expect(wrapper.text()).not.toContain('payment.result.success') expect(wrapper.text()).not.toContain('payment.result.failed') @@ -140,6 +145,7 @@ describe('PaymentResultView', () => { expiresAt: '2099-01-01T00:10:00.000Z', paymentType: 'alipay', payUrl: 'https://pay.example.com/session/42', + outTradeNo: 'sub2_20260420abcd1234', clientSecret: '', payAmount: 88, orderType: 'balance', @@ -147,12 +153,6 @@ describe('PaymentResultView', () => { resumeToken: 'resume-authoritative', createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), })) - pollOrderStatus.mockResolvedValue({ - ...orderFactory('PENDING'), - amount: 88, - pay_amount: 88, - fee_rate: 0, - }) resolveOrderPublicByResumeToken.mockResolvedValue({ data: { ...orderFactory('PAID'), @@ -172,7 +172,7 @@ describe('PaymentResultView', () => { await flushPromises() - expect(pollOrderStatus).toHaveBeenCalledWith(42) + expect(pollOrderStatus).not.toHaveBeenCalled() expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-authoritative') expect(wrapper.text()).toContain('payment.result.success') expect(wrapper.text()).toContain('103.00') @@ -227,7 +227,6 @@ describe('PaymentResultView', () => { trade_status: 'TRADE_SUCCESS', } resolveOrderPublicByResumeToken.mockRejectedValueOnce(new Error('resume failed')) - mount(PaymentResultView, { global: { stubs: { @@ -239,16 +238,19 @@ describe('PaymentResultView', () => { await flushPromises() expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-fail') - expect(verifyOrder).not.toHaveBeenCalled() + expect(verifyOrderPublic).not.toHaveBeenCalled() }) - it('does not use anonymous out_trade_no verification when no signed resume context is available', async () => { + it('uses public out_trade_no verification when no signed resume context is available', async () => { routeState.query = { out_trade_no: 'legacy-123', trade_status: 'TRADE_SUCCESS', } + verifyOrderPublic.mockResolvedValue({ + data: orderFactory('PAID'), + }) - mount(PaymentResultView, { + const wrapper = mount(PaymentResultView, { global: { stubs: { OrderStatusBadge: true, @@ -258,7 +260,9 @@ describe('PaymentResultView', () => { await flushPromises() - expect(verifyOrder).not.toHaveBeenCalled() + expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123') + expect(pollOrderStatus).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('payment.result.success') }) it('does not use public out_trade_no verification for bare order numbers without legacy return markers', async () => { @@ -276,7 +280,7 @@ describe('PaymentResultView', () => { await flushPromises() - expect(verifyOrder).not.toHaveBeenCalled() + expect(verifyOrderPublic).not.toHaveBeenCalled() }) it('resolves order by resume token when local recovery snapshot is missing', async () => { diff --git a/frontend/src/views/user/__tests__/PaymentView.spec.ts b/frontend/src/views/user/__tests__/PaymentView.spec.ts index f60ea962..66648da4 100644 --- a/frontend/src/views/user/__tests__/PaymentView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentView.spec.ts @@ -117,6 +117,7 @@ function jsapiOrderFixture(resumeToken: string) { fee_rate: 0, expires_at: '2099-01-01T00:10:00.000Z', payment_type: 'wxpay', + out_trade_no: 'sub2_jsapi_123', result_type: 'jsapi_ready' as const, resume_token: resumeToken, jsapi: { @@ -175,6 +176,7 @@ describe('PaymentView WeChat JSAPI flow', () => { path: '/payment/result', query: { order_id: '123', + out_trade_no: 'sub2_jsapi_123', resume_token: 'resume-token-123', }, }) @@ -202,4 +204,39 @@ describe('PaymentView WeChat JSAPI flow', () => { expect(routerPush).not.toHaveBeenCalled() expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() }) + + it('clears a stale recovery snapshot before handling wechat resume callback params', async () => { + createOrder.mockRejectedValueOnce(new Error('resume failed')) + window.localStorage.setItem(PAYMENT_RECOVERY_STORAGE_KEY, JSON.stringify({ + orderId: 999, + amount: 66, + qrCode: 'stale-qr', + expiresAt: '2099-01-01T00:10:00.000Z', + paymentType: 'alipay', + payUrl: 'https://pay.example.com/stale', + outTradeNo: 'stale-out-trade-no', + clientSecret: '', + payAmount: 66, + orderType: 'balance', + paymentMode: 'popup', + resumeToken: '', + createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), + })) + + shallowMount(PaymentView, { + global: { + stubs: { + Teleport: true, + Transition: false, + }, + }, + }) + await flushPromises() + await flushPromises() + + expect(createOrder).toHaveBeenCalledWith(expect.objectContaining({ + wechat_resume_token: 'resume-token-123', + })) + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() + }) }) diff --git a/frontend/src/views/user/paymentWechatResume.ts b/frontend/src/views/user/paymentWechatResume.ts index f53c8457..64f254da 100644 --- a/frontend/src/views/user/paymentWechatResume.ts +++ b/frontend/src/views/user/paymentWechatResume.ts @@ -19,12 +19,20 @@ function readQueryString(query: LocationQuery, key: string): string { return typeof value === 'string' ? value : '' } +export function hasWechatResumeQuery(query: LocationQuery): boolean { + if (readQueryString(query, 'wechat_resume') === '1') { + return true + } + return readQueryString(query, 'wechat_resume_token') !== '' + || readQueryString(query, 'openid') !== '' +} + export function parseWechatResumeRoute( query: LocationQuery, plans: SubscriptionPlan[], fallbackBalanceAmount: number, ): ParsedWechatResumeRoute | null { - if (readQueryString(query, 'wechat_resume') !== '1') { + if (!hasWechatResumeQuery(query)) { return null } From 84628108fc8dd5437ebfff4ad8f6ceaf8b24091f Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 11:17:32 +0800 Subject: [PATCH 03/31] fix(auth): preserve backward-compatible oauth defaults --- backend/internal/config/config.go | 15 +- backend/internal/config/config_test.go | 38 ++++- .../internal/handler/admin/setting_handler.go | 40 +++--- .../internal/handler/auth_linuxdo_oauth.go | 49 +++---- .../handler/auth_linuxdo_oauth_test.go | 79 ++++++++++- backend/internal/handler/auth_oidc_oauth.go | 134 +++++++++++------- .../internal/handler/auth_oidc_oauth_test.go | 88 +++++++++++- backend/internal/handler/auth_wechat_oauth.go | 7 +- .../handler/auth_wechat_oauth_test.go | 80 +++++++++++ backend/internal/service/setting_service.go | 8 +- .../setting_service_oidc_config_test.go | 44 ++++++ .../service/setting_service_public_test.go | 20 +++ backend/internal/service/user_service.go | 51 ++++++- backend/internal/service/user_service_test.go | 67 +++++++++ .../ProfileIdentityBindingsSection.vue | 25 +++- .../ProfileIdentityBindingsSection.spec.ts | 22 +++ frontend/src/views/admin/SettingsView.vue | 12 +- .../admin/__tests__/SettingsView.spec.ts | 24 ++++ 18 files changed, 661 insertions(+), 142 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index f355a15d..32ad91b7 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1202,7 +1202,7 @@ func setDefaults() { viper.SetDefault("linuxdo_connect.redirect_url", "") viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") - viper.SetDefault("linuxdo_connect.use_pkce", true) + viper.SetDefault("linuxdo_connect.use_pkce", false) viper.SetDefault("linuxdo_connect.userinfo_email_path", "") viper.SetDefault("linuxdo_connect.userinfo_id_path", "") viper.SetDefault("linuxdo_connect.userinfo_username_path", "") @@ -1222,8 +1222,8 @@ func setDefaults() { viper.SetDefault("oidc_connect.redirect_url", "") viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback") viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post") - viper.SetDefault("oidc_connect.use_pkce", true) - viper.SetDefault("oidc_connect.validate_id_token", true) + viper.SetDefault("oidc_connect.use_pkce", false) + viper.SetDefault("oidc_connect.validate_id_token", false) viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256") viper.SetDefault("oidc_connect.clock_skew_seconds", 120) viper.SetDefault("oidc_connect.require_email_verified", false) @@ -1613,9 +1613,6 @@ func (c *Config) Validate() error { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } if c.LinuxDo.Enabled { - if !c.LinuxDo.UsePKCE { - return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true") - } if strings.TrimSpace(c.LinuxDo.ClientID) == "" { return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") } @@ -1668,12 +1665,6 @@ func (c *Config) Validate() error { warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) } if c.OIDC.Enabled { - if !c.OIDC.UsePKCE { - return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true") - } - if !c.OIDC.ValidateIDToken { - return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true") - } if strings.TrimSpace(c.OIDC.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index fe48541b..f40a5f57 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -346,7 +346,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { } } -func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { +func TestValidateLinuxDoAllowsDisablingPKCEForCompatibility(t *testing.T) { resetViperWithJWTSecret(t) cfg, err := Load() @@ -363,11 +363,8 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { cfg.LinuxDo.UsePKCE = false err = cfg.Validate() - if err == nil { - t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil") - } - if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") { - t.Fatalf("Validate() expected use_pkce error, got: %v", err) + if err != nil { + t.Fatalf("Validate() expected LinuxDo config without PKCE to pass for compatibility, got: %v", err) } } @@ -427,6 +424,35 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T } } +func TestValidateOIDCAllowsDisablingPKCEAndIDTokenValidation(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.OIDC.Enabled = true + cfg.OIDC.ClientID = "oidc-client" + cfg.OIDC.ClientSecret = "oidc-secret" + cfg.OIDC.IssuerURL = "https://issuer.example.com" + cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth" + cfg.OIDC.TokenURL = "https://issuer.example.com/token" + cfg.OIDC.UserInfoURL = "https://issuer.example.com/userinfo" + cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback" + cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" + cfg.OIDC.Scopes = "openid email profile" + cfg.OIDC.UsePKCE = false + cfg.OIDC.ValidateIDToken = false + cfg.OIDC.JWKSURL = "" + cfg.OIDC.AllowedSigningAlgs = "" + + err = cfg.Validate() + if err != nil { + t.Fatalf("Validate() expected OIDC config without PKCE/id_token validation to pass for compatibility, got: %v", err) + } +} + func TestLoadDefaultDashboardCacheConfig(t *testing.T) { resetViperWithJWTSecret(t) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index e6609c97..f85f199b 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -653,20 +653,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode) } } - if req.WeChatConnectRedirectURL == "" { - response.BadRequest(c, "WeChat Redirect URL is required when enabled") - return - } - if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil { - response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL") - return - } - if req.WeChatConnectFrontendRedirectURL == "" { - req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback" - } - if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil { - response.BadRequest(c, "WeChat Frontend Redirect URL is invalid") - return + if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled { + if req.WeChatConnectRedirectURL == "" { + response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil { + response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL") + return + } + if req.WeChatConnectFrontendRedirectURL == "" { + req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback" + } + if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil { + response.BadRequest(c, "WeChat Frontend Redirect URL is invalid") + return + } } } @@ -749,14 +751,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "OIDC scopes must contain openid") return } - if !req.OIDCConnectUsePKCE { - response.BadRequest(c, "OIDC PKCE must be enabled") - return - } - if !req.OIDCConnectValidateIDToken { - response.BadRequest(c, "OIDC ID Token validation must be enabled") - return - } switch req.OIDCConnectTokenAuthMethod { case "", "client_secret_post", "client_secret_basic", "none": default: @@ -767,7 +761,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") return } - if req.OIDCConnectAllowedSigningAlgs == "" { + if req.OIDCConnectValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" { response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") return } diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 2bd44e78..ef9a5bca 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -123,13 +123,16 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) } - verifier, err := oauth.GenerateCodeVerifier() - if err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) - return + codeChallenge := "" + if cfg.UsePKCE { + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) } - codeChallenge := oauth.GenerateCodeChallenge(verifier) - setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { @@ -200,10 +203,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName) intent = normalizeOAuthIntent(intent) - codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie) - if codeVerifier == "" { - redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") - return + codeVerifier := "" + if cfg.UsePKCE { + codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return + } } redirectURI := strings.TrimSpace(cfg.RedirectURL) @@ -292,25 +298,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } if existingIdentityUser != nil { - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") - if err != nil { - redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) - return - } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ Intent: oauthIntentLogin, Identity: identityKey, - TargetUserID: &user.ID, + TargetUserID: &existingIdentityUser.ID, ResolvedEmail: existingIdentityUser.Email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ - "access_token": tokenPair.AccessToken, - "refresh_token": tokenPair.RefreshToken, - "expires_in": tokenPair.ExpiresIn, - "token_type": "Bearer", - "redirect": redirectTo, + "redirect": redirectTo, }, }); err != nil { redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") @@ -546,7 +543,9 @@ func linuxDoExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - form.Set("code_verifier", codeVerifier) + if strings.TrimSpace(codeVerifier) != "" { + form.Set("code_verifier", codeVerifier) + } r := client.R(). SetContext(ctx). @@ -699,8 +698,10 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod q.Set("scope", cfg.Scopes) } q.Set("state", state) - q.Set("code_challenge", codeChallenge) - q.Set("code_challenge_method", "S256") + if strings.TrimSpace(codeChallenge) != "" { + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + } u.RawQuery = q.Encode() return u.String(), nil diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index a3d87dfb..841dc442 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -171,6 +171,80 @@ func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { require.Equal(t, int64(42), userID) } +func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) { + handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: "https://connect.linux.do/oauth/authorize", + TokenURL: "https://connect.linux.do/oauth/token", + UserInfoURL: "https://connect.linux.do/api/user", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: false, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil) + + handler.LinuxDoOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=") + require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie)) +} + +func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + require.NoError(t, r.ParseForm()) + require.Empty(t, r.PostForm.Get("code_verifier")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: false, + }) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) +} + func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) { handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ Enabled: true, @@ -327,7 +401,10 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) require.True(t, ok) require.Equal(t, "/dashboard", completion["redirect"]) - require.NotEmpty(t, completion["access_token"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) + _, hasRefreshToken := completion["refresh_token"] + require.False(t, hasRefreshToken) require.Nil(t, completion["error"]) } diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index d2042a87..7fe4b8d9 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -157,21 +157,25 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) { } codeChallenge := "" - verifier, genErr := oauth.GenerateCodeVerifier() - if genErr != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr)) - return + if cfg.UsePKCE { + verifier, genErr := oauth.GenerateCodeVerifier() + if genErr != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie) } - codeChallenge = oauth.GenerateCodeChallenge(verifier) - oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie) nonce := "" - nonce, err = oauth.GenerateState() - if err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err)) - return + if cfg.ValidateIDToken { + nonce, err = oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err)) + return + } + oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie) } - oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie) redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { @@ -244,17 +248,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { intent = normalizeOAuthIntent(intent) codeVerifier := "" - codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie) - if codeVerifier == "" { - redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") - return + if cfg.UsePKCE { + codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return + } } expectedNonce := "" - expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie) - if expectedNonce == "" { - redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "") - return + if cfg.ValidateIDToken { + expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie) + if expectedNonce == "" { + redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "") + return + } } redirectURI := strings.TrimSpace(cfg.RedirectURL) @@ -284,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } - if strings.TrimSpace(tokenResp.IDToken) == "" { - redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "") - return - } + var idClaims *oidcIDTokenClaims + if cfg.ValidateIDToken { + if strings.TrimSpace(tokenResp.IDToken) == "" { + redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "") + return + } - idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce) - if err != nil { - log.Printf("[OIDC OAuth] id_token validation failed: %v", err) - redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "") - return + idClaims, err = oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce) + if err != nil { + log.Printf("[OIDC OAuth] id_token validation failed: %v", err) + redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "") + return + } } userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp) @@ -303,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } - subject := strings.TrimSpace(idClaims.Subject) + subject := "" + if idClaims != nil { + subject = strings.TrimSpace(idClaims.Subject) + } if subject == "" { subject = strings.TrimSpace(userInfoClaims.Subject) } @@ -311,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "") return } - issuer := strings.TrimSpace(idClaims.Issuer) + issuer := "" + if idClaims != nil { + issuer = strings.TrimSpace(idClaims.Issuer) + } if issuer == "" { issuer = strings.TrimSpace(cfg.IssuerURL) } @@ -321,21 +338,34 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } emailVerified := userInfoClaims.EmailVerified - if emailVerified == nil { + if emailVerified == nil && idClaims != nil { emailVerified = idClaims.EmailVerified } - if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) { + if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) { redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "") return } identityKey := oidcIdentityKey(issuer, subject) - compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email)) + compatEmail := strings.TrimSpace(userInfoClaims.Email) + if compatEmail == "" && idClaims != nil { + compatEmail = strings.TrimSpace(idClaims.Email) + } email := oidcSyntheticEmailFromIdentityKey(identityKey) username := firstNonEmpty( userInfoClaims.Username, - idClaims.PreferredUsername, - idClaims.Name, + func() string { + if idClaims != nil { + return idClaims.PreferredUsername + } + return "" + }(), + func() string { + if idClaims != nil { + return idClaims.Name + } + return "" + }(), oidcFallbackUsername(subject), ) identityRef := service.PendingAuthIdentityKey{ @@ -350,7 +380,12 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { "issuer": issuer, "email_verified": emailVerified != nil && *emailVerified, "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string { + if idClaims != nil { + return idClaims.Name + } + return "" + }(), username), "suggested_avatar_url": userInfoClaims.AvatarURL, } if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { @@ -387,25 +422,16 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } if existingIdentityUser != nil { - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") - if err != nil { - redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) - return - } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ Intent: oauthIntentLogin, Identity: identityRef, - TargetUserID: &user.ID, + TargetUserID: &existingIdentityUser.ID, ResolvedEmail: existingIdentityUser.Email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ - "access_token": tokenPair.AccessToken, - "refresh_token": tokenPair.RefreshToken, - "expires_in": tokenPair.ExpiresIn, - "token_type": "Bearer", - "redirect": redirectTo, + "redirect": redirectTo, }, }); err != nil { redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") @@ -670,7 +696,9 @@ func oidcExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - form.Set("code_verifier", codeVerifier) + if strings.TrimSpace(codeVerifier) != "" { + form.Set("code_verifier", codeVerifier) + } r := client.R(). SetContext(ctx). @@ -872,9 +900,13 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall q.Set("scope", cfg.Scopes) } q.Set("state", state) - q.Set("nonce", nonce) - q.Set("code_challenge", codeChallenge) - q.Set("code_challenge_method", "S256") + if strings.TrimSpace(nonce) != "" { + q.Set("nonce", nonce) + } + if strings.TrimSpace(codeChallenge) != "" { + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + } u.RawQuery = q.Encode() return u.String(), nil diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index 2acca18a..a600fd56 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -186,6 +186,89 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { require.Equal(t, int64(84), userID) } +func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) { + handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{ + Enabled: true, + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/oauth/authorize", + TokenURL: "https://issuer.example.com/oauth/token", + UserInfoURL: "https://issuer.example.com/oauth/userinfo", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: false, + ValidateIDToken: false, + RequireEmailVerified: false, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil) + + handler.OIDCOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotContains(t, location, "code_challenge=") + require.NotContains(t, location, "nonce=") + require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie)) + require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie)) +} + +func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + require.NoError(t, r.ParseForm()) + require.Empty(t, r.PostForm.Get("code_verifier")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{ + Enabled: true, + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: false, + ValidateIDToken: false, + RequireEmailVerified: false, + }) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) +} + func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) { cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ Subject: "oidc-subject-login", @@ -250,7 +333,10 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) require.True(t, ok) require.Equal(t, "/dashboard", completion["redirect"]) - require.NotEmpty(t, completion["access_token"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) + _, hasRefreshToken := completion["refresh_token"] + require.False(t, hasRefreshToken) require.Nil(t, completion["error"]) } diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 78f5d7c2..39703ce7 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -279,12 +279,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") - if err != nil { - redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) - return - } - if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil { + if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil { redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 937daa6d..99006701 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -213,6 +213,86 @@ func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMo require.Equal(t, "third_party_signup", completion["choice_reason"]) } +func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback")) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(wechatSyntheticEmail("union-456")). + SetUsername("wechat-existing-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"username": "wechat-existing-user"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completion["redirect"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) + _, hasRefreshToken := completion["refresh_token"] + require.False(t, hasRefreshToken) +} + func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL t.Cleanup(func() { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index fe566fec..059bbcd3 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -631,7 +631,7 @@ func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string mpReady := mpEnabled && webRedirectReady && mpAppID != "" && mpAppSecret != "" mobileReady := mobileEnabled && mobileAppID != "" && mobileAppSecret != "" - return openReady || mpReady || mobileReady, openReady, mpReady, mobileReady + return openReady || mpReady, openReady, mpReady, mobileReady } // filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON @@ -1693,8 +1693,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken } - result.OIDCConnectUsePKCE = true - result.OIDCConnectValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) } else { @@ -2196,8 +2194,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { effective.RedirectURL = strings.TrimSpace(v) } - effective.UsePKCE = true - if !effective.Enabled { return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") } @@ -2421,8 +2417,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { effective.ValidateIDToken = raw == "true" } - effective.UsePKCE = true - effective.ValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { effective.AllowedSigningAlgs = strings.TrimSpace(v) } diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go index 3809b332..a5a3959a 100644 --- a/backend/internal/service/setting_service_oidc_config_test.go +++ b/backend/internal/service/setting_service_oidc_config_test.go @@ -101,3 +101,47 @@ func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testi require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL) require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL) } + +func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t *testing.T) { + svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{}) + + got := svc.parseSettings(map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + SettingKeyOIDCConnectUsePKCE: "false", + SettingKeyOIDCConnectValidateIDToken: "false", + }) + + require.False(t, got.OIDCConnectUsePKCE) + require.False(t, got.OIDCConnectValidateIDToken) +} + +func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { + cfg := &config.Config{ + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + }, + } + + repo := &settingOIDCRepoStub{values: map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + SettingKeyOIDCConnectUsePKCE: "false", + SettingKeyOIDCConnectValidateIDToken: "false", + }} + svc := NewSettingService(repo, cfg) + + got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.False(t, got.UsePKCE) + require.False(t, got.ValidateIDToken) +} diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 497d1e36..4c7ca14b 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -112,3 +112,23 @@ func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t * require.True(t, settings.WeChatOAuthOpenEnabled) require.True(t, settings.WeChatOAuthMPEnabled) } + +func TestSettingService_GetPublicSettings_DoesNotExposeMobileOnlyWeChatAsWebOAuthAvailable(t *testing.T) { + svc := NewSettingService(&settingPublicRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectMobileEnabled: "true", + SettingKeyWeChatConnectMode: "mobile", + SettingKeyWeChatConnectMobileAppID: "wx-mobile-app", + SettingKeyWeChatConnectMobileAppSecret: "wx-mobile-secret", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.WeChatOAuthEnabled) + require.False(t, settings.WeChatOAuthOpenEnabled) + require.False(t, settings.WeChatOAuthMPEnabled) + require.True(t, settings.WeChatOAuthMobileEnabled) +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index bc444af5..c16d810b 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -248,12 +248,59 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in return UserIdentitySummarySet{}, err } - return UserIdentitySummarySet{ + summaries := UserIdentitySummarySet{ Email: s.buildEmailIdentitySummary(user, records), LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records), OIDC: s.buildProviderIdentitySummary("oidc", user, records), WeChat: s.buildProviderIdentitySummary("wechat", user, records), - }, nil + } + + s.applyExplicitProviderAvailability(ctx, &summaries) + return summaries, nil +} + +func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, summaries *UserIdentitySummarySet) { + if s == nil || summaries == nil || s.settingRepo == nil { + return + } + + settings, err := s.settingRepo.GetMultiple(ctx, []string{ + SettingKeyLinuxDoConnectEnabled, + SettingKeyOIDCConnectEnabled, + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectOpenEnabled, + SettingKeyWeChatConnectMPEnabled, + SettingKeyWeChatConnectMobileEnabled, + SettingKeyWeChatConnectMode, + }) + if err != nil { + return + } + + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" { + disableIdentityBindAction(&summaries.LinuxDo) + } + if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" { + disableIdentityBindAction(&summaries.OIDC) + } + if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok && strings.TrimSpace(raw) != "" { + if raw != "true" { + disableIdentityBindAction(&summaries.WeChat) + return + } + openEnabled, mpEnabled, _ := parseWeChatConnectCapabilitySettings(settings, true, settings[SettingKeyWeChatConnectMode]) + if !openEnabled && !mpEnabled { + disableIdentityBindAction(&summaries.WeChat) + } + } +} + +func disableIdentityBindAction(summary *UserIdentitySummary) { + if summary == nil || summary.Bound { + return + } + summary.CanBind = false + summary.BindStartPath = "" } func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) { diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 88bb1637..109d459d 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -51,6 +51,44 @@ type mockUserRepoTxState struct { deleteAvatarIDs []int64 } +type mockUserSettingRepo struct { + values map[string]string +} + +func (m *mockUserSettingRepo) Get(context.Context, string) (*Setting, error) { + panic("unexpected Get call") +} + +func (m *mockUserSettingRepo) GetValue(context.Context, string) (string, error) { + panic("unexpected GetValue call") +} + +func (m *mockUserSettingRepo) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (m *mockUserSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := m.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (m *mockUserSettingRepo) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (m *mockUserSettingRepo) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (m *mockUserSettingRepo) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + func (m *mockUserRepo) Create(context.Context, *User) error { return nil } func (m *mockUserRepo) GetByID(ctx context.Context, _ int64) (*User, error) { if m.getByIDErr != nil { @@ -382,6 +420,35 @@ func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testin require.True(t, summaries.LinuxDo.CanBind) } +func TestGetProfileIdentitySummaries_HidesBindActionWhenProviderExplicitlyDisabled(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 15, + Email: "alice@example.com", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "alice@example.com", + }, + }, + } + settingRepo := &mockUserSettingRepo{ + values: map[string]string{ + SettingKeyLinuxDoConnectEnabled: "false", + }, + } + svc := NewUserService(repo, settingRepo, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 15, repo.getByIDUser) + + require.NoError(t, err) + require.False(t, summaries.LinuxDo.Bound) + require.False(t, summaries.LinuxDo.CanBind) + require.Empty(t, summaries.LinuxDo.BindStartPath) +} + func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { repo := &mockUserRepo{} svc := NewUserService(repo, nil, nil, nil) // billingCache = nil diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue index 848789d9..48b1b879 100644 --- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue +++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue @@ -362,6 +362,16 @@ function getBindingDetails(provider: UserAuthProvider): UserAuthBindingStatus | return binding } +function isProviderEnabledForBinding(provider: BindableProvider): boolean { + if (provider === 'linuxdo') { + return props.linuxdoEnabled + } + if (provider === 'oidc') { + return props.oidcEnabled + } + return resolvedWeChatBinding.value.mode !== null +} + const providerItems = computed(() => [ { provider: 'email' as const, @@ -375,7 +385,10 @@ const providerItems = computed(() => [ provider: 'linuxdo' as const, label: t('profile.authBindings.providers.linuxdo'), bound: getBindingStatus('linuxdo'), - canBind: getBindingDetails('linuxdo')?.can_bind ?? (props.linuxdoEnabled && !getBindingStatus('linuxdo')), + canBind: + !getBindingStatus('linuxdo') && + isProviderEnabledForBinding('linuxdo') && + (getBindingDetails('linuxdo')?.can_bind ?? true), canUnbind: Boolean(getBindingStatus('linuxdo') && getBindingDetails('linuxdo')?.can_unbind), details: getBindingDetails('linuxdo'), }, @@ -383,7 +396,10 @@ const providerItems = computed(() => [ provider: 'oidc' as const, label: t('profile.authBindings.providers.oidc', { providerName: props.oidcProviderName }), bound: getBindingStatus('oidc'), - canBind: getBindingDetails('oidc')?.can_bind ?? (props.oidcEnabled && !getBindingStatus('oidc')), + canBind: + !getBindingStatus('oidc') && + isProviderEnabledForBinding('oidc') && + (getBindingDetails('oidc')?.can_bind ?? true), canUnbind: Boolean(getBindingStatus('oidc') && getBindingDetails('oidc')?.can_unbind), details: getBindingDetails('oidc'), }, @@ -391,7 +407,10 @@ const providerItems = computed(() => [ provider: 'wechat' as const, label: t('profile.authBindings.providers.wechat'), bound: getBindingStatus('wechat'), - canBind: getBindingDetails('wechat')?.can_bind ?? (resolvedWeChatBinding.value.mode !== null && !getBindingStatus('wechat')), + canBind: + !getBindingStatus('wechat') && + isProviderEnabledForBinding('wechat') && + (getBindingDetails('wechat')?.can_bind ?? true), canUnbind: Boolean(getBindingStatus('wechat') && getBindingDetails('wechat')?.can_unbind), details: getBindingDetails('wechat'), }, diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts index 345e0209..9d8c88d4 100644 --- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts +++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts @@ -474,4 +474,26 @@ describe('ProfileIdentityBindingsSection', () => { expect(userApiMocks.unbindAuthIdentity).toHaveBeenCalledWith('linuxdo') expect(wrapper.get('[data-testid="profile-binding-linuxdo-status"]').text()).toBe('Not bound') }) + + it('hides bind actions when provider details say bindable but the provider is disabled', () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + global: { + plugins: [pinia], + }, + props: { + user: createUser({ + auth_bindings: { + linuxdo: { bound: false, can_bind: true }, + oidc: { bound: false, can_bind: true }, + }, + }), + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: false, + }, + }) + + expect(wrapper.find('[data-testid="profile-binding-linuxdo-action"]').exists()).toBe(false) + expect(wrapper.find('[data-testid="profile-binding-oidc-action"]').exists()).toBe(false) + }) }) diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index a13f1981..5772c501 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2032,7 +2032,7 @@ @@ -2046,7 +2046,7 @@ @@ -4961,8 +4961,8 @@ const form = reactive({ oidc_connect_redirect_url: "", oidc_connect_frontend_redirect_url: "/auth/oidc/callback", oidc_connect_token_auth_method: "client_secret_post", - oidc_connect_use_pkce: true, - oidc_connect_validate_id_token: true, + oidc_connect_use_pkce: false, + oidc_connect_validate_id_token: false, oidc_connect_allowed_signing_algs: "RS256,ES256,PS256", oidc_connect_clock_skew_seconds: 120, oidc_connect_require_email_verified: false, @@ -5846,8 +5846,8 @@ async function saveSettings() { oidc_connect_frontend_redirect_url: form.oidc_connect_frontend_redirect_url, oidc_connect_token_auth_method: form.oidc_connect_token_auth_method, - oidc_connect_use_pkce: true, - oidc_connect_validate_id_token: true, + oidc_connect_use_pkce: form.oidc_connect_use_pkce, + oidc_connect_validate_id_token: form.oidc_connect_validate_id_token, oidc_connect_allowed_signing_algs: form.oidc_connect_allowed_signing_algs, oidc_connect_clock_skew_seconds: form.oidc_connect_clock_skew_seconds, oidc_connect_require_email_verified: diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts index 27a43c9f..10c51b2a 100644 --- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts +++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts @@ -776,4 +776,28 @@ describe("admin SettingsView wechat connect controls", () => { ).toBe(true); expect(wrapper.text()).toContain("首次绑定时授权"); }); + + it("preserves optional OIDC compatibility flags instead of forcing them on save", async () => { + getSettings.mockResolvedValueOnce({ + ...baseSettingsResponse, + oidc_connect_enabled: true, + oidc_connect_use_pkce: false, + oidc_connect_validate_id_token: false, + }); + + const wrapper = mountView(); + + await flushPromises(); + await openSecurityTab(wrapper); + await wrapper.find("form").trigger("submit.prevent"); + await flushPromises(); + + expect(updateSettings).toHaveBeenCalledTimes(1); + expect(updateSettings).toHaveBeenCalledWith( + expect.objectContaining({ + oidc_connect_use_pkce: false, + oidc_connect_validate_id_token: false, + }), + ); + }); }); From ca1f30a9113f363ce864bfb102ddd7286210fa2d Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 11:17:38 +0800 Subject: [PATCH 04/31] fix(auth): harden pending oauth session consumption --- .../handler/auth_oauth_pending_flow.go | 73 ++++++++++++------- .../handler/auth_oauth_pending_flow_test.go | 10 +-- .../service/auth_pending_identity_service.go | 35 +++++++-- .../auth_pending_identity_service_test.go | 66 +++++++++++++++++ 4 files changed, 146 insertions(+), 38 deletions(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 7d7b50f4..c7cd6103 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -277,6 +277,22 @@ func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool { return false } +func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool { + if session == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) { + return false + } + if session.TargetUserID == nil || *session.TargetUserID <= 0 { + return false + } + if pendingSessionWantsInvitation(payload) { + return false + } + return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == "" +} + func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error { if session == nil { return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") @@ -1212,13 +1228,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt( if session == nil || len(payload) == 0 { return false, nil } - if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) { - return false, nil - } - if !pendingOAuthCompletionIncludesTokenPayload(payload) { - return false, nil - } - if session.TargetUserID == nil || *session.TargetUserID <= 0 { + if !pendingOAuthCompletionCanIssueTokenPair(session, payload) { return false, nil } if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" && @@ -1649,6 +1659,22 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { } } applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + + canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload) + var loginUser *service.User + if canIssueTokenPair { + loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + } skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload) if err != nil { clearCookies() @@ -1658,25 +1684,6 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { if skipAdoptionPrompt { delete(payload, "adoption_required") } - if pendingOAuthCompletionIncludesTokenPayload(payload) { - if session.TargetUserID == nil || *session.TargetUserID <= 0 { - clearCookies() - response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) - return - } - user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID) - if err != nil { - clearCookies() - response.ErrorFrom(c, err) - return - } - if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { - clearCookies() - response.ErrorFrom(c, err) - return - } - h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) - } if pendingSessionWantsInvitation(payload) { if adoptionDecision.hasDecision() { @@ -1724,6 +1731,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { return } + if canIssueTokenPair { + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "") + if err != nil { + clearCookies() + response.InternalError(c, "Failed to generate token pair") + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID) + payload["access_token"] = tokenPair.AccessToken + payload["refresh_token"] = tokenPair.RefreshToken + payload["expires_in"] = tokenPair.ExpiresIn + payload["token_type"] = "Bearer" + } + clearCookies() response.Success(c, payload) } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 8940e37d..6f457206 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -746,11 +746,7 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo }). SetLocalFlowState(map[string]any{ oauthCompletionResponseKey: map[string]any{ - "access_token": "access-token", - "refresh_token": "refresh-token", - "expires_in": float64(3600), - "token_type": "Bearer", - "redirect": "/dashboard", + "redirect": "/dashboard", }, }). SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). @@ -769,8 +765,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo require.Equal(t, http.StatusOK, recorder.Code) payload := decodeJSONResponseData(t, recorder) - require.Equal(t, "access-token", payload["access_token"]) - require.Equal(t, "refresh-token", payload["refresh_token"]) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) require.Equal(t, "/dashboard", payload["redirect"]) require.Equal(t, "Existing Login Example", payload["suggested_display_name"]) require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"]) diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go index 7001ee18..26732c77 100644 --- a/backend/internal/service/auth_pending_identity_service.go +++ b/backend/internal/service/auth_pending_identity_service.go @@ -237,15 +237,40 @@ func (s *AuthPendingIdentityService) consumeSession( } now := time.Now().UTC() - updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + Where( + pendingauthsession.ConsumedAtIsNil(), + pendingauthsession.ExpiresAtGTE(now), + pendingauthsession.Or( + pendingauthsession.CompletionCodeExpiresAtIsNil(), + pendingauthsession.CompletionCodeExpiresAtGTE(now), + ), + ). SetConsumedAt(now). SetCompletionCodeHash(""). - ClearCompletionCodeExpiresAt(). - Save(ctx) - if err != nil { + ClearCompletionCodeExpiresAt() + if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" { + update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey)) + } + updated, err := update.Save(ctx) + if err == nil { + return updated, nil + } + if !dbent.IsNotFound(err) { return nil, err } - return updated, nil + + current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID) + if currentErr != nil { + if dbent.IsNotFound(currentErr) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, currentErr + } + if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil { + return nil, err + } + return nil, consumedErr } func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go index de0b18d2..deeeeb06 100644 --- a/backend/internal/service/auth_pending_identity_service_test.go +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -356,3 +356,69 @@ func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) { _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) } + +func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "stale-replay-subject", + }, + BrowserSessionKey: "browser-session", + }) + require.NoError(t, err) + + loaded, err := svc.getBrowserSession(ctx, session.SessionToken) + require.NoError(t, err) + + consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + _, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) + require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "legacy-token-subject", + }, + BrowserSessionKey: "browser-session", + LocalFlowState: map[string]any{ + "completion_response": map[string]any{ + "access_token": "legacy-access-token", + "refresh_token": "legacy-refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }, + }, + }) + require.NoError(t, err) + + consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + stored, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + + completion, ok := stored.LocalFlowState["completion_response"].(map[string]any) + require.True(t, ok) + require.NotContains(t, completion, "access_token") + require.NotContains(t, completion, "refresh_token") + require.NotContains(t, completion, "expires_in") + require.NotContains(t, completion, "token_type") + require.Equal(t, "/dashboard", completion["redirect"]) +} From 18481a100b761742c26711f2e95ee2bf44c2c0e6 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 11:17:45 +0800 Subject: [PATCH 05/31] fix(migrations): defer online ddl follow-ups safely --- .../migrate/auth_identity_fk_ondelete_test.go | 73 +++++++++++++++++++ .../internal/repository/migrations_runner.go | 50 +++++++------ .../migrations_runner_checksum_test.go | 27 +++++++ .../109_auth_identity_compat_backfill.sql | 3 - ...rce_payment_orders_out_trade_no_unique.sql | 13 ++-- ...ayment_orders_out_trade_no_unique_notx.sql | 10 +++ ...h_identity_migration_report_type_widen.sql | 2 + ...tity_payment_migrations_regression_test.go | 32 +++++++- 8 files changed, 173 insertions(+), 37 deletions(-) create mode 100644 backend/ent/migrate/auth_identity_fk_ondelete_test.go create mode 100644 backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql create mode 100644 backend/migrations/121_auth_identity_migration_report_type_widen.sql diff --git a/backend/ent/migrate/auth_identity_fk_ondelete_test.go b/backend/ent/migrate/auth_identity_fk_ondelete_test.go new file mode 100644 index 00000000..0e37025a --- /dev/null +++ b/backend/ent/migrate/auth_identity_fk_ondelete_test.go @@ -0,0 +1,73 @@ +package migrate + +import ( + "testing" + + "entgo.io/ent/dialect/entsql" + entschema "entgo.io/ent/dialect/sql/schema" + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityFoundationForeignKeyOnDeleteActions(t *testing.T) { + require.Equal( + t, + entschema.Cascade, + findForeignKeyBySymbol(t, AuthIdentitiesTable, "auth_identities_users_auth_identities").OnDelete, + ) + require.Equal( + t, + entschema.Cascade, + findForeignKeyBySymbol(t, AuthIdentityChannelsTable, "auth_identity_channels_auth_identities_channels").OnDelete, + ) + require.Equal( + t, + entschema.Cascade, + findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_pending_auth_sessions_adoption_decision").OnDelete, + ) + + require.Equal( + t, + entschema.SetNull, + findForeignKeyBySymbol(t, PendingAuthSessionsTable, "pending_auth_sessions_users_pending_auth_sessions").OnDelete, + ) + require.Equal( + t, + entschema.SetNull, + findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_auth_identities_adoption_decisions").OnDelete, + ) +} + +func TestPaymentOrdersOutTradeNoPartialUniqueIndex(t *testing.T) { + idx := findIndexByName(t, PaymentOrdersTable, "paymentorder_out_trade_no") + require.True(t, idx.Unique) + require.Len(t, idx.Columns, 1) + require.Equal(t, "out_trade_no", idx.Columns[0].Name) + require.NotNil(t, idx.Annotation) + require.Equal(t, (&entsql.IndexAnnotation{Where: "out_trade_no <> ''"}).Where, idx.Annotation.Where) +} + +func findForeignKeyBySymbol(t *testing.T, table *entschema.Table, symbol string) *entschema.ForeignKey { + t.Helper() + + for _, fk := range table.ForeignKeys { + if fk.Symbol == symbol { + return fk + } + } + + require.Failf(t, "missing foreign key", "table %s should include foreign key %s", table.Name, symbol) + return nil +} + +func findIndexByName(t *testing.T, table *entschema.Table, name string) *entschema.Index { + t.Helper() + + for _, idx := range table.Indexes { + if idx.Name == name { + return idx + } + } + + require.Failf(t, "missing index", "table %s should include index %s", table.Name, name) + return nil +} diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 5a2e6677..edc85226 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -55,30 +55,17 @@ const nonTransactionalMigrationSuffix = "_notx.sql" type migrationChecksumCompatibilityRule struct { fileChecksum string acceptedDBChecksum map[string]struct{} + acceptedChecksums map[string]struct{} } // migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。 -// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。 +// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行, +// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。 var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{ - "054_drop_legacy_cache_columns.sql": { - fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", - acceptedDBChecksum: map[string]struct{}{ - "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, - }, - }, - "061_add_usage_log_request_type.sql": { - fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", - acceptedDBChecksum: map[string]struct{}{ - "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {}, - "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {}, - }, - }, - "109_auth_identity_compat_backfill.sql": { - fileChecksum: "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", - acceptedDBChecksum: map[string]struct{}{ - "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3": {}, - }, - }, + "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"), + "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"), + "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), + "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"), } // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 @@ -328,16 +315,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) { return version, version, hash, nil } +func checksumSet(values ...string) map[string]struct{} { + out := make(map[string]struct{}, len(values)) + for _, value := range values { + out[value] = struct{}{} + } + return out +} + +func newMigrationChecksumCompatibilityRule(fileChecksum string, acceptedDBChecksums ...string) migrationChecksumCompatibilityRule { + return migrationChecksumCompatibilityRule{ + fileChecksum: fileChecksum, + acceptedDBChecksum: checksumSet(acceptedDBChecksums...), + acceptedChecksums: checksumSet(append([]string{fileChecksum}, acceptedDBChecksums...)...), + } +} + func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool { rule, ok := migrationChecksumCompatibilityRules[name] if !ok { return false } - if rule.fileChecksum != fileChecksum { + _, dbOK := rule.acceptedChecksums[dbChecksum] + if !dbOK { return false } - _, ok = rule.acceptedDBChecksum[dbChecksum] - return ok + _, fileOK := rule.acceptedChecksums[fileChecksum] + return fileOK } func validateMigrationExecutionMode(name, content string) (bool, error) { diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go index 6030991b..dc241a75 100644 --- a/backend/internal/repository/migrations_runner_checksum_test.go +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -60,4 +60,31 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { ) require.True(t, ok) }) + + t.Run("109回滚到历史文件后仍兼容已应用的新checksum", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "109_auth_identity_compat_backfill.sql", + "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", + "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", + ) + require.True(t, ok) + }) + + t.Run("119历史checksum可兼容占位文件", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "119_enforce_payment_orders_out_trade_no_unique.sql", + "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34", + "0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", + ) + require.True(t, ok) + }) + + t.Run("119未知checksum不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "119_enforce_payment_orders_out_trade_no_unique.sql", + "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34", + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + ) + require.False(t, ok) + }) } diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql index 5147ae45..ddbbedbc 100644 --- a/backend/migrations/109_auth_identity_compat_backfill.sql +++ b/backend/migrations/109_auth_identity_compat_backfill.sql @@ -1,6 +1,3 @@ -ALTER TABLE auth_identity_migration_reports -ALTER COLUMN report_type TYPE VARCHAR(80); - INSERT INTO auth_identities ( user_id, provider_type, diff --git a/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql index 4e256562..15e2c15f 100644 --- a/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql +++ b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql @@ -1,7 +1,6 @@ --- Replace the legacy non-unique index with a partial unique index. --- Keep empty-string legacy rows compatible while enforcing uniqueness for real order IDs. -DROP INDEX IF EXISTS paymentorder_out_trade_no; - -CREATE UNIQUE INDEX IF NOT EXISTS paymentorder_out_trade_no - ON payment_orders (out_trade_no) - WHERE out_trade_no <> ''; +-- Intentionally left as a no-op. +-- The online index rollout lives in 120_enforce_payment_orders_out_trade_no_unique_notx.sql +DO $$ +BEGIN + NULL; +END $$; diff --git a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql new file mode 100644 index 00000000..fe47698d --- /dev/null +++ b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql @@ -0,0 +1,10 @@ +-- Build the payment order uniqueness guarantee online. +-- Create the new partial unique index concurrently first so writes keep flowing, +-- then remove the legacy index name once the replacement is ready. +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique; + +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; + +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no; diff --git a/backend/migrations/121_auth_identity_migration_report_type_widen.sql b/backend/migrations/121_auth_identity_migration_report_type_widen.sql new file mode 100644 index 00000000..66bfb44a --- /dev/null +++ b/backend/migrations/121_auth_identity_migration_report_type_widen.sql @@ -0,0 +1,2 @@ +ALTER TABLE auth_identity_migration_reports +ALTER COLUMN report_type TYPE VARCHAR(80); diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index 1c4a51fa..988876a9 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -26,12 +26,36 @@ func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T) require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING")) } -func TestMigration119EnforcesOutTradeNoPartialUniqueIndex(t *testing.T) { +func TestMigration109KeepsPublishedBackfillBodyAndDefersReportTypeWidening(t *testing.T) { + content, err := FS.ReadFile("109_auth_identity_compat_backfill.sql") + require.NoError(t, err) + + sql := string(content) + require.NotContains(t, sql, "ALTER TABLE auth_identity_migration_reports") + + followupContent, err := FS.ReadFile("121_auth_identity_migration_report_type_widen.sql") + require.NoError(t, err) + + followupSQL := string(followupContent) + require.Contains(t, followupSQL, "ALTER TABLE auth_identity_migration_reports") + require.Contains(t, followupSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)") +} + +func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) { content, err := FS.ReadFile("119_enforce_payment_orders_out_trade_no_unique.sql") require.NoError(t, err) sql := string(content) - require.Contains(t, sql, "DROP INDEX IF EXISTS paymentorder_out_trade_no") - require.Contains(t, sql, "CREATE UNIQUE INDEX IF NOT EXISTS paymentorder_out_trade_no") - require.Contains(t, sql, "WHERE out_trade_no <> ''") + require.Contains(t, sql, "120_enforce_payment_orders_out_trade_no_unique_notx.sql") + require.Contains(t, sql, "NULL;") + require.NotContains(t, sql, "CREATE UNIQUE INDEX") + require.NotContains(t, sql, "DROP INDEX") + + followupContent, err := FS.ReadFile("120_enforce_payment_orders_out_trade_no_unique_notx.sql") + require.NoError(t, err) + + followupSQL := string(followupContent) + require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique") + require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") + require.Contains(t, followupSQL, "WHERE out_trade_no <> ''") } From 454873221c1f2a3ea4df8b58a116fa73acc873a6 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 11:18:09 +0800 Subject: [PATCH 06/31] test(auth): strengthen pending oauth legacy token assertions --- .../handler/auth_oauth_pending_flow_test.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 6f457206..a212eb91 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -746,7 +746,11 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo }). SetLocalFlowState(map[string]any{ oauthCompletionResponseKey: map[string]any{ - "redirect": "/dashboard", + "access_token": "legacy-access-token", + "refresh_token": "legacy-refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", }, }). SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). @@ -767,6 +771,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo payload := decodeJSONResponseData(t, recorder) require.NotEmpty(t, payload["access_token"]) require.NotEmpty(t, payload["refresh_token"]) + require.NotEqual(t, "legacy-access-token", payload["access_token"]) + require.NotEqual(t, "legacy-refresh-token", payload["refresh_token"]) require.Equal(t, "/dashboard", payload["redirect"]) require.Equal(t, "Existing Login Example", payload["suggested_display_name"]) require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"]) @@ -781,6 +787,14 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) require.NotNil(t, storedSession.ConsumedAt) + + completion, ok := storedSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.NotContains(t, completion, "access_token") + require.NotContains(t, completion, "refresh_token") + require.NotContains(t, completion, "expires_in") + require.NotContains(t, completion, "token_type") + require.Equal(t, "/dashboard", completion["redirect"]) } func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) { From 9d5e9bbc1806c13f1a718584ac920e85ca67161b Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 11:28:58 +0800 Subject: [PATCH 07/31] fix(payment): respect configured visible method source --- .../service/payment_order_jsapi_test.go | 65 ++++++ .../service/payment_resume_service_test.go | 206 ++++++++++++++++++ .../payment_visible_method_instances.go | 52 ++++- 3 files changed, 322 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go index a89d0380..8c5e4fc0 100644 --- a/backend/internal/service/payment_order_jsapi_test.go +++ b/backend/internal/service/payment_order_jsapi_test.go @@ -31,3 +31,68 @@ func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *tes t.Fatal("expected official wxpay visible method to be detected from enabled provider instance") } } + +func TestUsesOfficialWxpayVisibleMethodRespectsConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) { + tests := []struct { + name string + source string + wantOfficial bool + }{ + { + name: "official source selected", + source: VisibleMethodSourceOfficialWechat, + wantOfficial: true, + }, + { + name: "easypay source selected", + source: VisibleMethodSourceEasyPayWechat, + wantOfficial: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create official wxpay instance: %v", err) + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + SetSortOrder(2). + Save(ctx) + if err != nil { + t.Fatalf("create easypay wxpay instance: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodWxpaySource: tt.source, + }, + }, + }, + } + + if got := svc.usesOfficialWxpayVisibleMethod(ctx); got != tt.wantOfficial { + t.Fatalf("usesOfficialWxpayVisibleMethod() = %v, want %v", got, tt.wantOfficial) + } + }) + } +} diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index ffa55e69..e19e0b99 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) func TestNormalizeVisibleMethods(t *testing.T) { @@ -419,6 +420,211 @@ func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) { } } +func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method payment.PaymentType + officialName string + officialTypes string + easyPayName string + easyPayTypes string + sourceSetting string + wantProvider string + }{ + { + name: "alipay uses official source", + method: payment.TypeAlipay, + officialName: "Official Alipay", + officialTypes: "alipay", + easyPayName: "EasyPay Alipay", + easyPayTypes: "alipay", + sourceSetting: VisibleMethodSourceOfficialAlipay, + wantProvider: payment.TypeAlipay, + }, + { + name: "alipay uses easypay source", + method: payment.TypeAlipay, + officialName: "Official Alipay", + officialTypes: "alipay", + easyPayName: "EasyPay Alipay", + easyPayTypes: "alipay", + sourceSetting: VisibleMethodSourceEasyPayAlipay, + wantProvider: payment.TypeEasyPay, + }, + { + name: "wxpay uses official source", + method: payment.TypeWxpay, + officialName: "Official WeChat", + officialTypes: "wxpay", + easyPayName: "EasyPay WeChat", + easyPayTypes: "wxpay", + sourceSetting: VisibleMethodSourceOfficialWechat, + wantProvider: payment.TypeWxpay, + }, + { + name: "wxpay uses easypay source", + method: payment.TypeWxpay, + officialName: "Official WeChat", + officialTypes: "wxpay", + easyPayName: "EasyPay WeChat", + easyPayTypes: "wxpay", + sourceSetting: VisibleMethodSourceEasyPayWechat, + wantProvider: payment.TypeEasyPay, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + officialProviderKey := payment.TypeAlipay + if tt.method == payment.TypeWxpay { + officialProviderKey = payment.TypeWxpay + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(officialProviderKey). + SetName(tt.officialName). + SetConfig("{}"). + SetSupportedTypes(tt.officialTypes). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create official provider: %v", err) + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName(tt.easyPayName). + SetConfig("{}"). + SetSupportedTypes(tt.easyPayTypes). + SetEnabled(true). + SetSortOrder(2). + Save(ctx) + if err != nil { + t.Fatalf("create easypay provider: %v", err) + } + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + visibleMethodSourceSettingKey(tt.method): tt.sourceSetting, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 12.5) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != tt.wantProvider { + t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, tt.wantProvider) + } + }) + } +} + +func TestVisibleMethodLoadBalancerRejectsMissingOrInvalidSourceWhenMultipleProvidersEnabled(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method payment.PaymentType + sourceValue string + wantMessage string + }{ + { + name: "missing alipay source", + method: payment.TypeAlipay, + sourceValue: "", + wantMessage: "alipay source is required when the visible method is enabled", + }, + { + name: "invalid wxpay source", + method: payment.TypeWxpay, + sourceValue: "stripe", + wantMessage: "wxpay source must be one of the supported payment providers", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + officialProviderKey := payment.TypeAlipay + officialSupportedTypes := "alipay" + officialName := "Official Alipay" + easyPaySupportedTypes := "alipay" + easyPayName := "EasyPay Alipay" + if tt.method == payment.TypeWxpay { + officialProviderKey = payment.TypeWxpay + officialSupportedTypes = "wxpay" + officialName = "Official WeChat" + easyPaySupportedTypes = "wxpay" + easyPayName = "EasyPay WeChat" + } + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(officialProviderKey). + SetName(officialName). + SetConfig("{}"). + SetSupportedTypes(officialSupportedTypes). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create official provider: %v", err) + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName(easyPayName). + SetConfig("{}"). + SetSupportedTypes(easyPaySupportedTypes). + SetEnabled(true). + SetSortOrder(2). + Save(ctx) + if err != nil { + t.Fatalf("create easypay provider: %v", err) + } + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + visibleMethodSourceSettingKey(tt.method): tt.sourceValue, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 9.9) + if err == nil { + t.Fatal("SelectInstance should reject invalid visible method source configuration") + } + if infraerrors.Reason(err) != "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE" { + t.Fatalf("Reason(err) = %q, want %q", infraerrors.Reason(err), "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE") + } + if infraerrors.Message(err) != tt.wantMessage { + t.Fatalf("Message(err) = %q, want %q", infraerrors.Message(err), tt.wantMessage) + } + }) + } +} + func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) { t.Parallel() diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go index 477e8e8b..39358522 100644 --- a/backend/internal/service/payment_visible_method_instances.go +++ b/backend/internal/service/payment_visible_method_instances.go @@ -82,6 +82,19 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta return filtered } +func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance { + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" { + return nil + } + for _, inst := range instances { + if strings.EqualFold(strings.TrimSpace(inst.ProviderKey), providerKey) { + return inst + } + } + return nil +} + func buildPaymentProviderConflictError(method string, conflicting *dbent.PaymentProviderInstance) error { metadata := map[string]string{ "payment_method": NormalizeVisibleMethod(method), @@ -133,6 +146,32 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts( return nil } +func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context.Context, method string) (string, error) { + method = NormalizeVisibleMethod(method) + sourceKey := visibleMethodSourceSettingKey(method) + rawSource := "" + if s != nil && s.settingRepo != nil && sourceKey != "" { + value, err := s.settingRepo.GetValue(ctx, sourceKey) + if err != nil { + return "", fmt.Errorf("get %s: %w", sourceKey, err) + } + rawSource = value + } + + normalizedSource, err := normalizeVisibleMethodSettingSource(method, rawSource, true) + if err != nil { + return "", err + } + providerKey, ok := VisibleMethodProviderKeyForSource(method, normalizedSource) + if !ok { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source must be one of the supported payment providers", method), + ) + } + return providerKey, nil +} + func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( ctx context.Context, method string, @@ -161,6 +200,17 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( case 1: return matching[0], nil default: - return nil, buildPaymentProviderConflictError(method, matching[0]) + providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method) + if err != nil { + return nil, err + } + selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey) + if selected == nil { + return nil, infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source has no enabled provider instance", method), + ) + } + return selected, nil } } From be9df2bea78bcacd78ed8aaf676102c2701478c5 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 11:29:05 +0800 Subject: [PATCH 08/31] fix(auth): scrub legacy pending oauth tokens on upgrade --- .../handler/auth_oauth_pending_flow.go | 3 ++ .../handler/auth_oauth_pending_flow_test.go | 16 ++++++++ .../service/auth_pending_identity_service.go | 25 ++++++++++++ ..._pending_auth_completion_token_cleanup.sql | 15 +++++++ ...y_auth_source_grant_on_signup_defaults.sql | 39 +++++++++++++++++++ ...tity_payment_migrations_regression_test.go | 25 ++++++++++++ 6 files changed, 123 insertions(+) create mode 100644 backend/migrations/122_pending_auth_completion_token_cleanup.sql create mode 100644 backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index c7cd6103..658a5f52 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -1290,6 +1290,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any { normalized := clonePendingMap(payload) + for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} { + delete(normalized, key) + } step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step"))) switch step { case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required": diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index a212eb91..c0413d4d 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -851,6 +851,22 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl require.Nil(t, storedSession.ConsumedAt) } +func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) { + payload := normalizePendingOAuthCompletionResponse(map[string]any{ + "access_token": "legacy-access-token", + "refresh_token": "legacy-refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }) + + require.NotContains(t, payload, "access_token") + require.NotContains(t, payload, "refresh_token") + require.NotContains(t, payload, "expires_in") + require.NotContains(t, payload, "token_type") + require.Equal(t, "/dashboard", payload["redirect"]) +} + func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, true) ctx := context.Background() diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go index 26732c77..cc0522ab 100644 --- a/backend/internal/service/auth_pending_identity_service.go +++ b/backend/internal/service/auth_pending_identity_service.go @@ -236,6 +236,7 @@ func (s *AuthPendingIdentityService) consumeSession( return nil, err } + sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState) now := time.Now().UTC() update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). Where( @@ -247,6 +248,7 @@ func (s *AuthPendingIdentityService) consumeSession( ), ). SetConsumedAt(now). + SetLocalFlowState(sanitizedLocalFlowState). SetCompletionCodeHash(""). ClearCompletionCodeExpiresAt() if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" { @@ -273,6 +275,29 @@ func (s *AuthPendingIdentityService) consumeSession( return nil, consumedErr } +func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any { + sanitized := copyPendingMap(localFlowState) + if len(sanitized) == 0 { + return sanitized + } + + rawCompletion, ok := sanitized["completion_response"] + if !ok { + return sanitized + } + completion, ok := rawCompletion.(map[string]any) + if !ok { + return sanitized + } + + cleanedCompletion := copyPendingMap(completion) + for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} { + delete(cleanedCompletion, key) + } + sanitized["completion_response"] = cleanedCompletion + return sanitized +} + func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { if session == nil { return ErrPendingAuthSessionNotFound diff --git a/backend/migrations/122_pending_auth_completion_token_cleanup.sql b/backend/migrations/122_pending_auth_completion_token_cleanup.sql new file mode 100644 index 00000000..e6341142 --- /dev/null +++ b/backend/migrations/122_pending_auth_completion_token_cleanup.sql @@ -0,0 +1,15 @@ +UPDATE pending_auth_sessions +SET + local_flow_state = jsonb_set( + local_flow_state, + '{completion_response}', + ((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'), + true + ) +WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object' + AND ( + (local_flow_state -> 'completion_response') ? 'access_token' + OR (local_flow_state -> 'completion_response') ? 'refresh_token' + OR (local_flow_state -> 'completion_response') ? 'expires_in' + OR (local_flow_state -> 'completion_response') ? 'token_type' + ); diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql new file mode 100644 index 00000000..f6053ef0 --- /dev/null +++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql @@ -0,0 +1,39 @@ +WITH migration_110 AS ( + SELECT applied_at + FROM schema_migrations + WHERE filename = '110_pending_auth_and_provider_default_grants.sql' +), +legacy_provider_defaults AS ( + SELECT provider_type + FROM ( + VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat') + ) AS providers(provider_type) + CROSS JOIN migration_110 + JOIN settings balance + ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance' + JOIN settings concurrency + ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency' + JOIN settings subscriptions + ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions' + JOIN settings grant_on_signup + ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' + JOIN settings grant_on_first_bind + ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind' + WHERE balance.value = '0' + AND concurrency.value = '5' + AND subscriptions.value = '[]' + AND grant_on_signup.value = 'true' + AND grant_on_first_bind.value = 'false' + AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' +) +UPDATE settings +SET + value = 'false', + updated_at = NOW() +FROM legacy_provider_defaults +WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup' + AND settings.value = 'true'; diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index 988876a9..48cc427b 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -59,3 +59,28 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) { require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") require.Contains(t, followupSQL, "WHERE out_trade_no <> ''") } + +func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) { + content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "UPDATE pending_auth_sessions") + require.Contains(t, sql, "completion_response") + require.Contains(t, sql, "access_token") + require.Contains(t, sql, "refresh_token") + require.Contains(t, sql, "expires_in") + require.Contains(t, sql, "token_type") +} + +func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) { + content, err := FS.ReadFile("123_fix_legacy_auth_source_grant_on_signup_defaults.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql") + require.Contains(t, sql, "schema_migrations") + require.Contains(t, sql, "updated_at") + require.Contains(t, sql, "'_grant_on_signup'") + require.Contains(t, sql, "value = 'false'") +} From 1ffebbb568375aa8d75f9e530fe8aad76bc747c9 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 12:29:52 +0800 Subject: [PATCH 09/31] fix(migrations): keep auth identity and payment upgrades safe --- ...ntity_legacy_migration_integration_test.go | 217 +++++++++++------- .../migrations_schema_integration_test.go | 89 +++++++ ...en_auth_identity_migration_report_type.sql | 14 ++ ...payment_orders_out_trade_no_index_name.sql | 22 ++ ...tity_payment_migrations_regression_test.go | 16 +- 5 files changed, 278 insertions(+), 80 deletions(-) create mode 100644 backend/migrations/108a_widen_auth_identity_migration_report_type.sql create mode 100644 backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go index e59c257c..41b64de7 100644 --- a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go +++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go @@ -4,6 +4,7 @@ package repository import ( "context" + "database/sql" "os" "path/filepath" "strconv" @@ -20,32 +21,8 @@ func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) { migrationSQL, err := os.ReadFile(migrationPath) require.NoError(t, err) - _, err = tx.ExecContext(ctx, ` -CREATE TABLE IF NOT EXISTS user_external_identities ( - id BIGSERIAL PRIMARY KEY, - user_id BIGINT NOT NULL, - provider TEXT NOT NULL, - provider_user_id TEXT NOT NULL, - provider_union_id TEXT NULL, - provider_username TEXT NOT NULL DEFAULT '', - display_name TEXT NOT NULL DEFAULT '', - profile_url TEXT NOT NULL DEFAULT '', - avatar_url TEXT NOT NULL DEFAULT '', - metadata TEXT NOT NULL DEFAULT '{}', - created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP -); - - TRUNCATE TABLE - auth_identity_channels, - identity_adoption_decisions, - auth_identities, - auth_identity_migration_reports, - user_external_identities, - users - RESTART IDENTITY CASCADE; -`) - require.NoError(t, err) + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) var linuxDoUserID int64 require.NoError(t, tx.QueryRowContext(ctx, ` @@ -218,32 +195,8 @@ func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectM migration116SQL, err := os.ReadFile(migration116Path) require.NoError(t, err) - _, err = tx.ExecContext(ctx, ` -CREATE TABLE IF NOT EXISTS user_external_identities ( - id BIGSERIAL PRIMARY KEY, - user_id BIGINT NOT NULL, - provider TEXT NOT NULL, - provider_user_id TEXT NOT NULL, - provider_union_id TEXT NULL, - provider_username TEXT NOT NULL DEFAULT '', - display_name TEXT NOT NULL DEFAULT '', - profile_url TEXT NOT NULL DEFAULT '', - avatar_url TEXT NOT NULL DEFAULT '', - metadata TEXT NOT NULL DEFAULT '{}', - created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP -); - -TRUNCATE TABLE - auth_identity_channels, - identity_adoption_decisions, - auth_identities, - auth_identity_migration_reports, - user_external_identities, - users -RESTART IDENTITY CASCADE; -`) - require.NoError(t, err) + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) var linuxDoMalformedUserID int64 require.NoError(t, tx.QueryRowContext(ctx, ` @@ -408,32 +361,8 @@ func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngrades migrationSQL, err := os.ReadFile(migrationPath) require.NoError(t, err) - _, err = tx.ExecContext(ctx, ` -CREATE TABLE IF NOT EXISTS user_external_identities ( - id BIGSERIAL PRIMARY KEY, - user_id BIGINT NOT NULL, - provider TEXT NOT NULL, - provider_user_id TEXT NOT NULL, - provider_union_id TEXT NULL, - provider_username TEXT NOT NULL DEFAULT '', - display_name TEXT NOT NULL DEFAULT '', - profile_url TEXT NOT NULL DEFAULT '', - avatar_url TEXT NOT NULL DEFAULT '', - metadata TEXT NOT NULL DEFAULT '{}', - created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP -); - - TRUNCATE TABLE - auth_identity_channels, - identity_adoption_decisions, - auth_identities, - auth_identity_migration_reports, - user_external_identities, - users - RESTART IDENTITY CASCADE; -`) - require.NoError(t, err) + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) userIDs := make([]int64, 0, 8) for _, email := range []string{ @@ -643,6 +572,136 @@ FROM auth_identity_migration_reports require.NoError(t, tx.QueryRowContext(ctx, ` SELECT COUNT(*) FROM auth_identity_migration_reports -`).Scan(&afterCount)) + `).Scan(&afterCount)) require.Equal(t, beforeCount, afterCount) } + +func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql") + migration108aSQL, err := os.ReadFile(migration108aPath) + require.NoError(t, err) + + migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql") + migration109SQL, err := os.ReadFile(migration109Path) + require.NoError(t, err) + + migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migration116SQL, err := os.ReadFile(migration116Path) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + _, err = tx.ExecContext(ctx, ` +ALTER TABLE auth_identity_migration_reports +ALTER COLUMN report_type TYPE VARCHAR(40); +`) + require.NoError(t, err) + + var oidcSyntheticUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&oidcSyntheticUserID)) + + var linuxdoLegacyUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxdoLegacyUserID)) + + var invalidMetadataLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid') +RETURNING id +`, linuxdoLegacyUserID).Scan(&invalidMetadataLegacyID)) + + _, err = tx.ExecContext(ctx, string(migration108aSQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration109SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration116SQL)) + require.NoError(t, err) + + var reportTypeWidth int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT character_maximum_length +FROM information_schema.columns +WHERE table_schema = 'public' + AND table_name = 'auth_identity_migration_reports' + AND column_name = 'report_type' +`).Scan(&reportTypeWidth)) + require.Equal(t, 80, reportTypeWidth) + + var oidcSyntheticRecoveryReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery' + AND report_key = $1 +`, strconv.FormatInt(oidcSyntheticUserID, 10)).Scan(&oidcSyntheticRecoveryReportCount)) + require.Equal(t, 1, oidcSyntheticRecoveryReportCount) + + var invalidMetadataReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_invalid_metadata_json' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(invalidMetadataLegacyID, 10)).Scan(&invalidMetadataReportCount)) + require.Equal(t, 1, invalidMetadataReportCount) +} + +func prepareLegacyExternalIdentitiesTable(t *testing.T, tx *sql.Tx, ctx context.Context) { + t.Helper() + + _, err := tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS user_external_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + provider_union_id TEXT NULL, + provider_username TEXT NOT NULL DEFAULT '', + display_name TEXT NOT NULL DEFAULT '', + profile_url TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', + metadata TEXT NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); +`) + require.NoError(t, err) +} + +func truncateAuthIdentityLegacyFixtureTables(t *testing.T, tx *sql.Tx, ctx context.Context) { + t.Helper() + + _, err := tx.ExecContext(ctx, ` +TRUNCATE TABLE + auth_identity_channels, + identity_adoption_decisions, + pending_auth_sessions, + auth_identities, + auth_identity_migration_reports, + user_provider_default_grants, + user_avatars, + user_external_identities, + users +RESTART IDENTITY CASCADE; +`) + require.NoError(t, err) +} diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index dd3019bb..ac4dea18 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -89,6 +89,22 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) } +func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) { + tx := testTx(t) + + requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false) + + requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE") + requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE") + requireForeignKeyOnDelete(t, tx, "pending_auth_sessions", "target_user_id", "users", "SET NULL") + requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE") + requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL") + + requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no") + requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE") + requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique") +} + func requireIndex(t *testing.T, tx *sql.Tx, table, index string) { t.Helper() @@ -106,6 +122,79 @@ SELECT EXISTS ( require.True(t, exists, "expected index %s on %s", index, table) } +func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) { + t.Helper() + + var exists bool + err := tx.QueryRowContext(context.Background(), ` +SELECT EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = $1 + AND indexname = $2 +) +`, table, index).Scan(&exists) + require.NoError(t, err, "query pg_indexes for %s.%s", table, index) + require.False(t, exists, "expected index %s on %s to be absent", index, table) +} + +func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) { + t.Helper() + + var ( + unique bool + def string + ) + + err := tx.QueryRowContext(context.Background(), ` +SELECT + i.indisunique, + pg_get_indexdef(i.indexrelid) +FROM pg_class idx +JOIN pg_index i ON i.indexrelid = idx.oid +JOIN pg_class tbl ON tbl.oid = i.indrelid +JOIN pg_namespace ns ON ns.oid = tbl.relnamespace +WHERE ns.nspname = 'public' + AND tbl.relname = $1 + AND idx.relname = $2 +`, table, index).Scan(&unique, &def) + require.NoError(t, err, "query index definition for %s.%s", table, index) + require.True(t, unique, "expected index %s on %s to be unique", index, table) + + for _, fragment := range fragments { + require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment) + } +} + +func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) { + t.Helper() + + var actual string + err := tx.QueryRowContext(context.Background(), ` +SELECT CASE c.confdeltype + WHEN 'a' THEN 'NO ACTION' + WHEN 'r' THEN 'RESTRICT' + WHEN 'c' THEN 'CASCADE' + WHEN 'n' THEN 'SET NULL' + WHEN 'd' THEN 'SET DEFAULT' +END +FROM pg_constraint c +JOIN pg_class tbl ON tbl.oid = c.conrelid +JOIN pg_namespace ns ON ns.oid = tbl.relnamespace +JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid +JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey) +WHERE ns.nspname = 'public' + AND c.contype = 'f' + AND tbl.relname = $1 + AND attr.attname = $2 + AND ref_tbl.relname = $3 +LIMIT 1 +`, table, column, refTable).Scan(&actual) + require.NoError(t, err, "query foreign key action for %s.%s -> %s", table, column, refTable) + require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable) +} + func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { t.Helper() diff --git a/backend/migrations/108a_widen_auth_identity_migration_report_type.sql b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql new file mode 100644 index 00000000..bc170fb8 --- /dev/null +++ b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql @@ -0,0 +1,14 @@ +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'auth_identity_migration_reports' + AND column_name = 'report_type' + AND COALESCE(character_maximum_length, 0) < 80 + ) THEN + ALTER TABLE auth_identity_migration_reports + ALTER COLUMN report_type TYPE VARCHAR(80); + END IF; +END $$; diff --git a/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql new file mode 100644 index 00000000..ef2599dc --- /dev/null +++ b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql @@ -0,0 +1,22 @@ +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = 'payment_orders' + AND indexname = 'paymentorder_out_trade_no_unique' + ) THEN + IF EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = 'payment_orders' + AND indexname = 'paymentorder_out_trade_no' + ) THEN + EXECUTE 'DROP INDEX IF EXISTS paymentorder_out_trade_no'; + END IF; + + EXECUTE 'ALTER INDEX paymentorder_out_trade_no_unique RENAME TO paymentorder_out_trade_no'; + END IF; +END $$; diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index 48cc427b..dbf8fc47 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -26,7 +26,14 @@ func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T) require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING")) } -func TestMigration109KeepsPublishedBackfillBodyAndDefersReportTypeWidening(t *testing.T) { +func TestAuthIdentityReportTypeWideningRunsBeforeLongReportWritersAndStillReconcilesAt121(t *testing.T) { + preflightContent, err := FS.ReadFile("108a_widen_auth_identity_migration_report_type.sql") + require.NoError(t, err) + + preflightSQL := string(preflightContent) + require.Contains(t, preflightSQL, "ALTER TABLE auth_identity_migration_reports") + require.Contains(t, preflightSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)") + content, err := FS.ReadFile("109_auth_identity_compat_backfill.sql") require.NoError(t, err) @@ -58,6 +65,13 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) { require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique") require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") require.Contains(t, followupSQL, "WHERE out_trade_no <> ''") + + alignmentContent, err := FS.ReadFile("120a_align_payment_orders_out_trade_no_index_name.sql") + require.NoError(t, err) + + alignmentSQL := string(alignmentContent) + require.Contains(t, alignmentSQL, "paymentorder_out_trade_no_unique") + require.Contains(t, alignmentSQL, "RENAME TO paymentorder_out_trade_no") } func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) { From 767f2f2dfe4e08490ca28f073c0156b2b6b7d912 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 12:30:00 +0800 Subject: [PATCH 10/31] fix(auth): harden pending oauth and backend mode flows --- backend/internal/handler/auth_handler.go | 2 + .../internal/handler/auth_linuxdo_oauth.go | 9 ++ .../handler/auth_linuxdo_oauth_test.go | 55 ++++++++ .../handler/auth_oauth_logout_test.go | 68 +++++++++ .../handler/auth_oauth_pending_flow.go | 129 ++++++++++++++++++ .../handler/auth_oauth_pending_flow_test.go | 20 +++ backend/internal/handler/auth_oidc_oauth.go | 23 +++- .../internal/handler/auth_oidc_oauth_test.go | 56 ++++++++ backend/internal/handler/auth_wechat_oauth.go | 9 ++ .../handler/auth_wechat_oauth_test.go | 104 ++++++++++---- .../server/middleware/backend_mode_guard.go | 47 +++++-- .../middleware/backend_mode_guard_test.go | 90 ++++++++++++ 12 files changed, 568 insertions(+), 44 deletions(-) create mode 100644 backend/internal/handler/auth_oauth_logout_test.go diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 9801b3b3..acd43e9f 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -678,6 +678,8 @@ func (h *AuthHandler) Logout(c *gin.Context) { // 不影响登出流程 } } + h.consumePendingOAuthSessionOnLogout(c) + clearOAuthLogoutCookies(c) response.Success(c, LogoutResponse{ Message: "Logged out successfully", diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index ef9a5bca..157be066 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -469,6 +469,15 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 841dc442..a9a5e3e6 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -757,6 +757,61 @@ func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *test require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-choice-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-choice-subject-1"). + SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { t.Helper() handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) diff --git a/backend/internal/handler/auth_oauth_logout_test.go b/backend/internal/handler/auth_oauth_logout_test.go new file mode 100644 index 00000000..0d4f94b1 --- /dev/null +++ b/backend/internal/handler/auth_oauth_logout_test.go @@ -0,0 +1,68 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("logout-pending-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("logout-subject-123"). + SetBrowserSessionKey("logout-browser-session-key"). + SetResolvedEmail("logout@example.com"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")}) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"}) + req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")}) + req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")}) + req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")}) + req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")}) + ginCtx.Request = req + + handler.Logout(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + cookies := recorder.Result().Cookies() + for _, name := range []string{ + oauthPendingSessionCookieName, + oauthPendingBrowserCookieName, + oauthBindAccessTokenCookieName, + linuxDoOAuthStateCookieName, + oidcOAuthStateCookieName, + wechatOAuthStateCookieName, + wechatPaymentOAuthStateName, + } { + cookie := findCookie(cookies, name) + require.NotNil(t, cookie, name) + require.Equal(t, -1, cookie.MaxAge, name) + require.True(t, cookie.HttpOnly, name) + } + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 658a5f52..c5df4db1 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -310,6 +310,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes return nil } +func buildLegacyCompleteRegistrationPendingResponse( + session *dbent.PendingAuthSession, + forceEmailOnSignup bool, + emailVerificationRequired bool, +) map[string]any { + completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + })) + + if email := strings.TrimSpace(session.ResolvedEmail); email != "" { + if _, exists := completionResponse["email"]; !exists { + completionResponse["email"] = email + } + if _, exists := completionResponse["resolved_email"]; !exists { + completionResponse["resolved_email"] = email + } + } + if _, exists := completionResponse["choice_reason"]; !exists { + switch { + case forceEmailOnSignup: + completionResponse["choice_reason"] = "force_email_on_signup" + case emailVerificationRequired: + completionResponse["choice_reason"] = "email_verification_required" + default: + completionResponse["choice_reason"] = "third_party_signup" + } + } + return completionResponse +} + +func (h *AuthHandler) legacyCompleteRegistrationSessionStatus( + c *gin.Context, + session *dbent.PendingAuthSession, +) (*dbent.PendingAuthSession, bool, error) { + if session == nil { + return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil)) + if step := pendingSessionStringValue(payload, "step"); step != "" { + return session, true, nil + } + + emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context()) + forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context()) + if !emailVerificationRequired && !forceEmailOnSignup { + return session, false, nil + } + + client := h.entClient() + if client == nil { + return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + updatedSession, err := updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + strings.TrimSpace(session.Intent), + strings.TrimSpace(session.ResolvedEmail), + nil, + buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired), + ) + if err != nil { + return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err) + } + return updatedSession, true, nil +} + func (r oauthAdoptionDecisionRequest) hasDecision() bool { return r.AdoptDisplayName != nil || r.AdoptAvatar != nil } @@ -1272,6 +1344,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au return svc, session, clearCookies, nil } +func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) { + if c == nil || c.Request == nil { + return + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + return + } + + svc, err := h.pendingIdentityService() + if err != nil { + return + } + _, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) +} + +func clearOAuthLogoutCookies(c *gin.Context) { + secureCookie := isRequestHTTPS(c) + + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + clearOAuthBindAccessTokenCookie(c, secureCookie) + + clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) + clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) + clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie) + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) + + oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie) + oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie) + oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie) + oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) + + wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) + + wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie) +} + func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H { completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil)) payload := gin.H{ @@ -1451,6 +1576,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) response.ErrorFrom(c, err) return } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { response.BadRequest(c, "Pending oauth session provider mismatch") return diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index c0413d4d..b3b8dfe1 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -1228,6 +1228,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) require.Nil(t, storedSession.ConsumedAt) } +func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) { + handler, _ := newOAuthPendingFlowTestHandler(t, false) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")}) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"}) + ginCtx.Request = req + + handler.Logout(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge) +} + func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) { handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810") ctx := context.Background() diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 7fe4b8d9..6345938b 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -374,19 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ProviderSubject: subject, } upstreamClaims := map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string { if idClaims != nil { return idClaims.Name } return "" }(), username), - "suggested_avatar_url": userInfoClaims.AvatarURL, + "suggested_avatar_url": userInfoClaims.AvatarURL, } if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { upstreamClaims["compat_email"] = compatEmail @@ -622,6 +622,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index a600fd56..63008344 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -692,6 +692,62 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-choice-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-choice-subject-1"). + SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + type oidcProviderFixture struct { Subject string PreferredUsername string diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 39703ce7..3ed20a7d 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -525,6 +525,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 99006701..349e7dd2 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -19,7 +19,6 @@ import ( "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" - dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/repository" @@ -700,7 +699,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes require.Zero(t, count) } -func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { +func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -773,27 +772,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing require.Equal(t, http.StatusOK, completeRecorder.Code) responseData := decodeJSONBody(t, completeRecorder) - require.NotEmpty(t, responseData["access_token"]) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["adoption_required"]) + require.Empty(t, responseData["access_token"]) - userEntity, err := client.User.Query(). - Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")). + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(pendingSession.ID)). Only(ctx) require.NoError(t, err) - require.Equal(t, "WeChat Display", userEntity.Username) + require.Nil(t, consumed.ConsumedAt) - identity, err := client.AuthIdentity.Query(). + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + identityCount, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), authidentity.ProviderKeyEQ("wechat-main"), authidentity.ProviderSubjectEQ("union-456"), ). - Only(ctx) + Count(ctx) require.NoError(t, err) - require.Equal(t, userEntity.ID, identity.UserID) - require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) - require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + require.Zero(t, identityCount) - channel, err := client.AuthIdentityChannel.Query(). + channelCount, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ("wechat"), authidentitychannel.ProviderKeyEQ("wechat-main"), @@ -801,25 +805,15 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing authidentitychannel.ChannelAppIDEQ("wx-open-app"), authidentitychannel.ChannelSubjectEQ("openid-123"), ). - Only(ctx) + Count(ctx) require.NoError(t, err) - require.Equal(t, identity.ID, channel.IdentityID) - require.Equal(t, "union-456", channel.Metadata["unionid"]) + require.Zero(t, channelCount) - decision, err := client.IdentityAdoptionDecision.Query(). + decisionCount, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). - Only(ctx) + Count(ctx) require.NoError(t, err) - require.NotNil(t, decision.IdentityID) - require.Equal(t, identity.ID, *decision.IdentityID) - require.True(t, decision.AdoptDisplayName) - require.True(t, decision.AdoptAvatar) - - consumed, err := client.PendingAuthSession.Query(). - Where(pendingauthsession.IDEQ(pendingSession.ID)). - Only(ctx) - require.NoError(t, err) - require.NotNil(t, consumed.ConsumedAt) + require.Zero(t, decisionCount) } func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { @@ -981,6 +975,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-choice-session"). + SetIntent("login"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("wechat-choice-subject-1"). + SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid"). + SetBrowserSessionKey("wechat-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")}) + completeCtx.Request = req + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go index 46482af3..ae53037e 100644 --- a/backend/internal/server/middleware/backend_mode_guard.go +++ b/backend/internal/server/middleware/backend_mode_guard.go @@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun } } +func backendModeAllowsAuthPath(path string) bool { + path = strings.ToLower(strings.TrimSpace(path)) + for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} { + if strings.HasSuffix(path, suffix) { + return true + } + } + + for _, suffix := range []string{ + "/auth/oauth/linuxdo/callback", + "/auth/oauth/wechat/callback", + "/auth/oauth/wechat/payment/callback", + "/auth/oauth/oidc/callback", + "/auth/oauth/linuxdo/complete-registration", + "/auth/oauth/wechat/complete-registration", + "/auth/oauth/oidc/complete-registration", + "/auth/oauth/linuxdo/create-account", + "/auth/oauth/wechat/create-account", + "/auth/oauth/oidc/create-account", + "/auth/oauth/linuxdo/bind-login", + "/auth/oauth/wechat/bind-login", + "/auth/oauth/oidc/bind-login", + } { + if strings.HasSuffix(path, suffix) { + return true + } + } + + return strings.Contains(path, "/auth/oauth/pending/") +} + // BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled. -// Allows: login, login/2fa, logout, refresh (admin needs these). -// Blocks: register, forgot-password, reset-password, OAuth, etc. +// Allows the minimal auth surface admins still need in backend mode, including +// OAuth callbacks and pending continuations. Handler-level backend mode checks +// still enforce admin-only login and forbid self-service registration. func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc { return func(c *gin.Context) { if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) { c.Next() return } - path := c.Request.URL.Path - // Allow login, 2FA, logout, refresh, public settings - allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} - for _, suffix := range allowedSuffixes { - if strings.HasSuffix(path, suffix) { - c.Next() - return - } + if backendModeAllowsAuthPath(c.Request.URL.Path) { + c.Next() + return } response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.") c.Abort() diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go index 8878ebc9..bd77677b 100644 --- a/backend/internal/server/middleware/backend_mode_guard_test.go +++ b/backend/internal/server/middleware/backend_mode_guard_test.go @@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) { path: "/api/v1/auth/refresh", wantStatus: http.StatusOK, }, + { + name: "enabled_blocks_linuxdo_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_linuxdo_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_wechat_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_wechat_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_wechat_payment_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/payment/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_wechat_payment_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/payment/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_oidc_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_oidc_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_exchange", + enabled: "true", + path: "/api/v1/auth/oauth/pending/exchange", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_send_verify_code", + enabled: "true", + path: "/api/v1/auth/oauth/pending/send-verify-code", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_create_account", + enabled: "true", + path: "/api/v1/auth/oauth/pending/create-account", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_bind_login", + enabled: "true", + path: "/api/v1/auth/oauth/pending/bind-login", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_provider_bind_login", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/bind-login", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_provider_create_account", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/create-account", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_legacy_complete_registration", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/complete-registration", + wantStatus: http.StatusOK, + }, { name: "enabled_blocks_register", enabled: "true", From b2e07121901458173921a7241c4f846188363173 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 12:30:07 +0800 Subject: [PATCH 11/31] fix(settings): preserve oauth config compatibility on upgrade --- backend/internal/config/config.go | 294 +++++++++++++++++- backend/internal/config/config_test.go | 33 +- .../internal/handler/admin/setting_handler.go | 46 ++- backend/internal/server/api_contract_test.go | 194 ++++++++++++ backend/internal/service/setting_service.go | 242 ++++++++------ .../setting_service_oidc_config_test.go | 50 +++ .../service/setting_service_public_test.go | 19 ++ .../setting_service_wechat_config_test.go | 51 +++ 8 files changed, 830 insertions(+), 99 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 32ad91b7..d47eadd4 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -70,6 +70,7 @@ type Config struct { JWT JWTConfig `mapstructure:"jwt"` Totp TotpConfig `mapstructure:"totp"` LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + WeChat WeChatConnectConfig `mapstructure:"wechat_connect"` OIDC OIDCConnectConfig `mapstructure:"oidc_connect"` Default DefaultConfig `mapstructure:"default"` RateLimit RateLimitConfig `mapstructure:"rate_limit"` @@ -190,6 +191,25 @@ type LinuxDoConnectConfig struct { UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` } +type WeChatConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + AppID string `mapstructure:"app_id"` + AppSecret string `mapstructure:"app_secret"` + OpenAppID string `mapstructure:"open_app_id"` + OpenAppSecret string `mapstructure:"open_app_secret"` + MPAppID string `mapstructure:"mp_app_id"` + MPAppSecret string `mapstructure:"mp_app_secret"` + MobileAppID string `mapstructure:"mobile_app_id"` + MobileAppSecret string `mapstructure:"mobile_app_secret"` + OpenEnabled bool `mapstructure:"open_enabled"` + MPEnabled bool `mapstructure:"mp_enabled"` + MobileEnabled bool `mapstructure:"mobile_enabled"` + Mode string `mapstructure:"mode"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` +} + type OIDCConnectConfig struct { Enabled bool `mapstructure:"enabled"` ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 @@ -218,6 +238,217 @@ type OIDCConnectConfig struct { UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` } +const ( + defaultWeChatConnectMode = "open" + defaultWeChatConnectScopes = "snsapi_login" + defaultWeChatConnectFrontendRedirect = "/auth/wechat/callback" +) + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + +func normalizeWeChatConnectMode(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "mp": + return "mp" + case "mobile": + return "mobile" + default: + return defaultWeChatConnectMode + } +} + +func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string { + mode = normalizeWeChatConnectMode(mode) + switch mode { + case "open": + if openEnabled { + return "open" + } + case "mp": + if mpEnabled { + return "mp" + } + case "mobile": + if mobileEnabled { + return "mobile" + } + } + switch { + case openEnabled: + return "open" + case mpEnabled: + return "mp" + case mobileEnabled: + return "mobile" + default: + return mode + } +} + +func defaultWeChatConnectScopesForMode(mode string) string { + switch normalizeWeChatConnectMode(mode) { + case "mp": + return "snsapi_userinfo" + case "mobile": + return "" + default: + return defaultWeChatConnectScopes + } +} + +func normalizeWeChatConnectScopes(raw, mode string) string { + switch normalizeWeChatConnectMode(mode) { + case "mp": + switch strings.TrimSpace(raw) { + case "snsapi_base": + return "snsapi_base" + case "snsapi_userinfo": + return "snsapi_userinfo" + default: + return defaultWeChatConnectScopesForMode(mode) + } + case "mobile": + return "" + default: + return defaultWeChatConnectScopes + } +} + +func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool { + if viper.InConfig(configKey) { + return false + } + _, hasNewEnv := os.LookupEnv(envKey) + return !hasNewEnv +} + +func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) { + if cfg == nil { + return + } + + legacyOpenAppID := "" + if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_id", "WECHAT_CONNECT_OPEN_APP_ID") && + shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") { + legacyOpenAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) + if legacyOpenAppID != "" { + cfg.OpenAppID = legacyOpenAppID + } + } + + legacyOpenAppSecret := "" + if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_secret", "WECHAT_CONNECT_OPEN_APP_SECRET") && + shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") { + legacyOpenAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) + if legacyOpenAppSecret != "" { + cfg.OpenAppSecret = legacyOpenAppSecret + } + } + + legacyMPAppID := "" + if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_id", "WECHAT_CONNECT_MP_APP_ID") && + shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") { + legacyMPAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) + if legacyMPAppID != "" { + cfg.MPAppID = legacyMPAppID + } + } + + legacyMPAppSecret := "" + if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_secret", "WECHAT_CONNECT_MP_APP_SECRET") && + shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") { + legacyMPAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) + if legacyMPAppSecret != "" { + cfg.MPAppSecret = legacyMPAppSecret + } + } + + if shouldApplyLegacyWeChatEnv("wechat_connect.frontend_redirect_url", "WECHAT_CONNECT_FRONTEND_REDIRECT_URL") { + if legacyFrontend := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); legacyFrontend != "" { + cfg.FrontendRedirectURL = legacyFrontend + } + } + + hasLegacyOpen := legacyOpenAppID != "" && legacyOpenAppSecret != "" + hasLegacyMP := legacyMPAppID != "" && legacyMPAppSecret != "" + + if shouldApplyLegacyWeChatEnv("wechat_connect.enabled", "WECHAT_CONNECT_ENABLED") && (hasLegacyOpen || hasLegacyMP) { + cfg.Enabled = true + } + if shouldApplyLegacyWeChatEnv("wechat_connect.open_enabled", "WECHAT_CONNECT_OPEN_ENABLED") && hasLegacyOpen { + cfg.OpenEnabled = true + } + if shouldApplyLegacyWeChatEnv("wechat_connect.mp_enabled", "WECHAT_CONNECT_MP_ENABLED") && hasLegacyMP { + cfg.MPEnabled = true + } + if shouldApplyLegacyWeChatEnv("wechat_connect.mode", "WECHAT_CONNECT_MODE") { + switch { + case hasLegacyMP && !hasLegacyOpen: + cfg.Mode = "mp" + case hasLegacyOpen: + cfg.Mode = "open" + } + } + if shouldApplyLegacyWeChatEnv("wechat_connect.scopes", "WECHAT_CONNECT_SCOPES") { + switch { + case hasLegacyMP && !hasLegacyOpen: + cfg.Scopes = defaultWeChatConnectScopesForMode("mp") + case hasLegacyOpen: + cfg.Scopes = defaultWeChatConnectScopesForMode("open") + } + } +} + +func normalizeWeChatConnectConfig(cfg *WeChatConnectConfig) { + if cfg == nil { + return + } + + cfg.AppID = strings.TrimSpace(cfg.AppID) + cfg.AppSecret = strings.TrimSpace(cfg.AppSecret) + cfg.OpenAppID = strings.TrimSpace(cfg.OpenAppID) + cfg.OpenAppSecret = strings.TrimSpace(cfg.OpenAppSecret) + cfg.MPAppID = strings.TrimSpace(cfg.MPAppID) + cfg.MPAppSecret = strings.TrimSpace(cfg.MPAppSecret) + cfg.MobileAppID = strings.TrimSpace(cfg.MobileAppID) + cfg.MobileAppSecret = strings.TrimSpace(cfg.MobileAppSecret) + cfg.Mode = normalizeWeChatConnectMode(cfg.Mode) + cfg.RedirectURL = strings.TrimSpace(cfg.RedirectURL) + cfg.FrontendRedirectURL = strings.TrimSpace(cfg.FrontendRedirectURL) + + cfg.AppID = firstNonEmptyString(cfg.AppID, cfg.OpenAppID, cfg.MPAppID, cfg.MobileAppID) + cfg.AppSecret = firstNonEmptyString(cfg.AppSecret, cfg.OpenAppSecret, cfg.MPAppSecret, cfg.MobileAppSecret) + cfg.OpenAppID = firstNonEmptyString(cfg.OpenAppID, cfg.AppID) + cfg.OpenAppSecret = firstNonEmptyString(cfg.OpenAppSecret, cfg.AppSecret) + cfg.MPAppID = firstNonEmptyString(cfg.MPAppID, cfg.AppID) + cfg.MPAppSecret = firstNonEmptyString(cfg.MPAppSecret, cfg.AppSecret) + cfg.MobileAppID = firstNonEmptyString(cfg.MobileAppID, cfg.AppID) + cfg.MobileAppSecret = firstNonEmptyString(cfg.MobileAppSecret, cfg.AppSecret) + + if !cfg.OpenEnabled && !cfg.MPEnabled && !cfg.MobileEnabled && cfg.Enabled { + switch cfg.Mode { + case "mp": + cfg.MPEnabled = true + case "mobile": + cfg.MobileEnabled = true + default: + cfg.OpenEnabled = true + } + } + cfg.Mode = normalizeWeChatConnectStoredMode(cfg.OpenEnabled, cfg.MPEnabled, cfg.MobileEnabled, cfg.Mode) + cfg.Scopes = normalizeWeChatConnectScopes(cfg.Scopes, cfg.Mode) + if cfg.FrontendRedirectURL == "" { + cfg.FrontendRedirectURL = defaultWeChatConnectFrontendRedirect + } +} + // TokenRefreshConfig OAuth token自动刷新配置 type TokenRefreshConfig struct { // 是否启用自动刷新 @@ -1012,6 +1243,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) + applyLegacyWeChatConnectEnvCompatibility(&cfg.WeChat) + normalizeWeChatConnectConfig(&cfg.WeChat) cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName) cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID) cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret) @@ -1207,6 +1440,24 @@ func setDefaults() { viper.SetDefault("linuxdo_connect.userinfo_id_path", "") viper.SetDefault("linuxdo_connect.userinfo_username_path", "") + // WeChat Connect OAuth 登录 + viper.SetDefault("wechat_connect.enabled", false) + viper.SetDefault("wechat_connect.app_id", "") + viper.SetDefault("wechat_connect.app_secret", "") + viper.SetDefault("wechat_connect.open_app_id", "") + viper.SetDefault("wechat_connect.open_app_secret", "") + viper.SetDefault("wechat_connect.mp_app_id", "") + viper.SetDefault("wechat_connect.mp_app_secret", "") + viper.SetDefault("wechat_connect.mobile_app_id", "") + viper.SetDefault("wechat_connect.mobile_app_secret", "") + viper.SetDefault("wechat_connect.open_enabled", false) + viper.SetDefault("wechat_connect.mp_enabled", false) + viper.SetDefault("wechat_connect.mobile_enabled", false) + viper.SetDefault("wechat_connect.mode", defaultWeChatConnectMode) + viper.SetDefault("wechat_connect.scopes", defaultWeChatConnectScopes) + viper.SetDefault("wechat_connect.redirect_url", "") + viper.SetDefault("wechat_connect.frontend_redirect_url", defaultWeChatConnectFrontendRedirect) + // Generic OIDC OAuth 登录 viper.SetDefault("oidc_connect.enabled", false) viper.SetDefault("oidc_connect.provider_name", "OIDC") @@ -1222,8 +1473,8 @@ func setDefaults() { viper.SetDefault("oidc_connect.redirect_url", "") viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback") viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post") - viper.SetDefault("oidc_connect.use_pkce", false) - viper.SetDefault("oidc_connect.validate_id_token", false) + viper.SetDefault("oidc_connect.use_pkce", true) + viper.SetDefault("oidc_connect.validate_id_token", true) viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256") viper.SetDefault("oidc_connect.clock_skew_seconds", 120) viper.SetDefault("oidc_connect.require_email_verified", false) @@ -1664,6 +1915,45 @@ func (c *Config) Validate() error { warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) } + if c.WeChat.Enabled { + weChat := c.WeChat + normalizeWeChatConnectConfig(&weChat) + + if weChat.OpenEnabled { + if strings.TrimSpace(weChat.OpenAppID) == "" { + return fmt.Errorf("wechat_connect.open_app_id is required when wechat_connect.open_enabled=true") + } + if strings.TrimSpace(weChat.OpenAppSecret) == "" { + return fmt.Errorf("wechat_connect.open_app_secret is required when wechat_connect.open_enabled=true") + } + } + if weChat.MPEnabled { + if strings.TrimSpace(weChat.MPAppID) == "" { + return fmt.Errorf("wechat_connect.mp_app_id is required when wechat_connect.mp_enabled=true") + } + if strings.TrimSpace(weChat.MPAppSecret) == "" { + return fmt.Errorf("wechat_connect.mp_app_secret is required when wechat_connect.mp_enabled=true") + } + } + if weChat.MobileEnabled { + if strings.TrimSpace(weChat.MobileAppID) == "" { + return fmt.Errorf("wechat_connect.mobile_app_id is required when wechat_connect.mobile_enabled=true") + } + if strings.TrimSpace(weChat.MobileAppSecret) == "" { + return fmt.Errorf("wechat_connect.mobile_app_secret is required when wechat_connect.mobile_enabled=true") + } + } + if v := strings.TrimSpace(weChat.RedirectURL); v != "" { + if err := ValidateAbsoluteHTTPURL(v); err != nil { + return fmt.Errorf("wechat_connect.redirect_url invalid: %w", err) + } + warnIfInsecureURL("wechat_connect.redirect_url", v) + } + if err := ValidateFrontendRedirectURL(weChat.FrontendRedirectURL); err != nil { + return fmt.Errorf("wechat_connect.frontend_redirect_url invalid: %w", err) + } + warnIfInsecureURL("wechat_connect.frontend_redirect_url", weChat.FrontendRedirectURL) + } if c.OIDC.Enabled { if strings.TrimSpace(c.OIDC.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f40a5f57..8b59ef5f 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -225,6 +225,37 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { } } +func TestLoadWeChatConnectConfigFromLegacyEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app") + t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/legacy-callback") + + cfg, err := Load() + require.NoError(t, err) + require.True(t, cfg.WeChat.Enabled) + require.True(t, cfg.WeChat.OpenEnabled) + require.True(t, cfg.WeChat.MPEnabled) + require.False(t, cfg.WeChat.MobileEnabled) + require.Equal(t, "open", cfg.WeChat.Mode) + require.Equal(t, "wx-open-app", cfg.WeChat.OpenAppID) + require.Equal(t, "wx-open-secret", cfg.WeChat.OpenAppSecret) + require.Equal(t, "wx-mp-app", cfg.WeChat.MPAppID) + require.Equal(t, "wx-mp-secret", cfg.WeChat.MPAppSecret) + require.Equal(t, "/auth/wechat/legacy-callback", cfg.WeChat.FrontendRedirectURL) +} + +func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + require.NoError(t, err) + require.True(t, cfg.OIDC.UsePKCE) + require.True(t, cfg.OIDC.ValidateIDToken) +} + func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { resetViperWithJWTSecret(t) @@ -424,7 +455,7 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T } } -func TestValidateOIDCAllowsDisablingPKCEAndIDTokenValidation(t *testing.T) { +func TestValidateOIDCAllowsExplicitCompatibilityOverridesForPKCEAndIDTokenValidation(t *testing.T) { resetViperWithJWTSecret(t) cfg, err := Load() diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index f85f199b..d340a8a6 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -565,6 +565,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes) req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL) req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL) + req.WeChatConnectAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectAppID, previousSettings.WeChatConnectAppID)) + req.WeChatConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectRedirectURL, previousSettings.WeChatConnectRedirectURL)) + req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectFrontendRedirectURL, previousSettings.WeChatConnectFrontendRedirectURL)) + if req.WeChatConnectMode == "" { + req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(previousSettings.WeChatConnectMode)) + } + if req.WeChatConnectScopes == "" { + req.WeChatConnectScopes = strings.TrimSpace(previousSettings.WeChatConnectScopes) + } if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled { response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time") @@ -598,9 +607,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } - req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID)) - req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID)) - req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID)) + req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectOpenAppID, previousSettings.WeChatConnectAppID)) + req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMPAppID, previousSettings.WeChatConnectAppID)) + req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMobileAppID, previousSettings.WeChatConnectAppID)) if req.WeChatConnectOpenAppSecret == "" { req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret)) @@ -691,10 +700,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath) req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath) req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath) - - if req.OIDCConnectProviderName == "" { - req.OIDCConnectProviderName = "OIDC" + req.OIDCConnectProviderName = strings.TrimSpace(firstNonEmpty(req.OIDCConnectProviderName, previousSettings.OIDCConnectProviderName, "OIDC")) + req.OIDCConnectClientID = strings.TrimSpace(firstNonEmpty(req.OIDCConnectClientID, previousSettings.OIDCConnectClientID)) + req.OIDCConnectIssuerURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectIssuerURL, previousSettings.OIDCConnectIssuerURL)) + req.OIDCConnectDiscoveryURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectDiscoveryURL, previousSettings.OIDCConnectDiscoveryURL)) + req.OIDCConnectAuthorizeURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAuthorizeURL, previousSettings.OIDCConnectAuthorizeURL)) + req.OIDCConnectTokenURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenURL, previousSettings.OIDCConnectTokenURL)) + req.OIDCConnectUserInfoURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoURL, previousSettings.OIDCConnectUserInfoURL)) + req.OIDCConnectJWKSURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectJWKSURL, previousSettings.OIDCConnectJWKSURL)) + req.OIDCConnectScopes = strings.TrimSpace(firstNonEmpty(req.OIDCConnectScopes, previousSettings.OIDCConnectScopes, "openid email profile")) + req.OIDCConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectRedirectURL, previousSettings.OIDCConnectRedirectURL)) + req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectFrontendRedirectURL, previousSettings.OIDCConnectFrontendRedirectURL, "/auth/oidc/callback")) + req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenAuthMethod, previousSettings.OIDCConnectTokenAuthMethod, "client_secret_post"))) + req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAllowedSigningAlgs, previousSettings.OIDCConnectAllowedSigningAlgs, "RS256,ES256,PS256")) + req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath)) + req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath)) + req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, previousSettings.OIDCConnectUserInfoUsernamePath)) + if !req.OIDCConnectUsePKCE { + req.OIDCConnectUsePKCE = previousSettings.OIDCConnectUsePKCE } + if !req.OIDCConnectValidateIDToken { + req.OIDCConnectValidateIDToken = previousSettings.OIDCConnectValidateIDToken + } + if req.OIDCConnectClockSkewSeconds == 0 { + req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds + if req.OIDCConnectClockSkewSeconds == 0 { + req.OIDCConnectClockSkewSeconds = 120 + } + } + if req.OIDCConnectClientID == "" { response.BadRequest(c, "OIDC Client ID is required when enabled") return diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index ed7764cf..3d933dbc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -784,6 +784,198 @@ func TestAPIContracts(t *testing.T) { } }`, }, + { + name: "GET /api/v1/admin/settings falls back to config oauth defaults", + setup: func(t *testing.T, deps *contractDeps) { + t.Helper() + deps.cfg.OIDC = config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "ConfigOIDC", + ClientID: "oidc-config-client", + ClientSecret: "oidc-config-secret", + IssuerURL: "https://issuer.example.com", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256,ES256,PS256", + ClockSkewSeconds: 120, + } + deps.cfg.WeChat = config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + Mode: "open", + Scopes: "snsapi_login", + FrontendRedirectURL: "/auth/wechat/callback", + } + deps.settingRepo.SetAll(map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyEmailVerifyEnabled: "false", + service.SettingKeyRegistrationEmailSuffixWhitelist: "[]", + }) + }, + method: http.MethodGet, + path: "/api/v1/admin/settings", + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "registration_enabled": true, + "email_verify_enabled": false, + "registration_email_suffix_whitelist": [], + "promo_code_enabled": true, + "password_reset_enabled": false, + "frontend_url": "", + "invitation_code_enabled": false, + "totp_enabled": false, + "totp_encryption_key_configured": false, + "smtp_host": "", + "smtp_port": 587, + "smtp_username": "", + "smtp_password_configured": false, + "smtp_from_email": "", + "smtp_from_name": "", + "smtp_use_tls": false, + "turnstile_enabled": false, + "turnstile_site_key": "", + "turnstile_secret_key_configured": false, + "linuxdo_connect_enabled": false, + "linuxdo_connect_client_id": "", + "linuxdo_connect_client_secret_configured": false, + "linuxdo_connect_redirect_url": "", + "oidc_connect_enabled": true, + "oidc_connect_provider_name": "ConfigOIDC", + "oidc_connect_client_id": "oidc-config-client", + "oidc_connect_client_secret_configured": true, + "oidc_connect_issuer_url": "https://issuer.example.com", + "oidc_connect_discovery_url": "", + "oidc_connect_authorize_url": "", + "oidc_connect_token_url": "", + "oidc_connect_userinfo_url": "", + "oidc_connect_jwks_url": "", + "oidc_connect_scopes": "openid email profile", + "oidc_connect_redirect_url": "https://api.example.com/api/v1/auth/oauth/oidc/callback", + "oidc_connect_frontend_redirect_url": "/auth/oidc/callback", + "oidc_connect_token_auth_method": "client_secret_post", + "oidc_connect_use_pkce": true, + "oidc_connect_validate_id_token": true, + "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", + "oidc_connect_clock_skew_seconds": 120, + "oidc_connect_require_email_verified": false, + "oidc_connect_userinfo_email_path": "", + "oidc_connect_userinfo_id_path": "", + "oidc_connect_userinfo_username_path": "", + "site_name": "Sub2API", + "site_logo": "", + "site_subtitle": "Subscription to API Conversion Platform", + "api_base_url": "", + "contact_info": "", + "doc_url": "", + "home_content": "", + "hide_ccs_import_button": false, + "purchase_subscription_enabled": false, + "purchase_subscription_url": "", + "table_default_page_size": 20, + "table_page_size_options": [10, 20, 50], + "custom_menu_items": [], + "custom_endpoints": [], + "default_concurrency": 0, + "default_balance": 0, + "default_subscriptions": [], + "enable_model_fallback": false, + "fallback_model_anthropic": "claude-3-5-sonnet-20241022", + "fallback_model_openai": "gpt-4o", + "fallback_model_gemini": "gemini-2.5-pro", + "fallback_model_antigravity": "gemini-2.5-pro", + "enable_identity_patch": true, + "identity_patch_prompt": "", + "ops_monitoring_enabled": false, + "ops_realtime_monitoring_enabled": true, + "ops_query_mode_default": "auto", + "ops_metrics_interval_seconds": 60, + "min_claude_code_version": "", + "max_claude_code_version": "", + "allow_ungrouped_key_scheduling": false, + "backend_mode_enabled": false, + "enable_fingerprint_unification": true, + "enable_metadata_passthrough": false, + "enable_cch_signing": false, + "web_search_emulation_enabled": false, + "payment_visible_method_alipay_source": "", + "payment_visible_method_wxpay_source": "", + "payment_visible_method_alipay_enabled": false, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": false, + "payment_enabled": false, + "payment_min_amount": 0, + "payment_max_amount": 0, + "payment_daily_limit": 0, + "payment_order_timeout_minutes": 0, + "payment_max_pending_orders": 0, + "payment_enabled_types": null, + "payment_balance_disabled": false, + "payment_balance_recharge_multiplier": 0, + "payment_recharge_fee_rate": 0, + "payment_load_balance_strategy": "", + "payment_product_name_prefix": "", + "payment_product_name_suffix": "", + "payment_help_image_url": "", + "payment_help_text": "", + "payment_cancel_rate_limit_enabled": false, + "payment_cancel_rate_limit_max": 0, + "payment_cancel_rate_limit_window": 0, + "payment_cancel_rate_limit_unit": "", + "payment_cancel_rate_limit_window_mode": "", + "balance_low_notify_enabled": false, + "account_quota_notify_enabled": false, + "balance_low_notify_threshold": 0, + "balance_low_notify_recharge_url": "", + "account_quota_notify_emails": [], + "wechat_connect_enabled": true, + "wechat_connect_app_id": "wx-open-config", + "wechat_connect_app_secret_configured": true, + "wechat_connect_mode": "open", + "wechat_connect_open_enabled": true, + "wechat_connect_open_app_id": "wx-open-config", + "wechat_connect_open_app_secret_configured": true, + "wechat_connect_mp_enabled": false, + "wechat_connect_mp_app_id": "wx-open-config", + "wechat_connect_mp_app_secret_configured": true, + "wechat_connect_mobile_enabled": false, + "wechat_connect_mobile_app_id": "wx-open-config", + "wechat_connect_mobile_app_secret_configured": true, + "wechat_connect_redirect_url": "", + "wechat_connect_frontend_redirect_url": "/auth/wechat/callback", + "wechat_connect_scopes": "snsapi_login", + "auth_source_default_email_balance": 0, + "auth_source_default_email_concurrency": 5, + "auth_source_default_email_subscriptions": [], + "auth_source_default_email_grant_on_signup": false, + "auth_source_default_email_grant_on_first_bind": false, + "auth_source_default_linuxdo_balance": 0, + "auth_source_default_linuxdo_concurrency": 5, + "auth_source_default_linuxdo_subscriptions": [], + "auth_source_default_linuxdo_grant_on_signup": false, + "auth_source_default_linuxdo_grant_on_first_bind": false, + "auth_source_default_oidc_balance": 0, + "auth_source_default_oidc_concurrency": 5, + "auth_source_default_oidc_subscriptions": [], + "auth_source_default_oidc_grant_on_signup": false, + "auth_source_default_oidc_grant_on_first_bind": false, + "auth_source_default_wechat_balance": 0, + "auth_source_default_wechat_concurrency": 5, + "auth_source_default_wechat_subscriptions": [], + "auth_source_default_wechat_grant_on_signup": false, + "auth_source_default_wechat_grant_on_first_bind": false, + "force_email_on_third_party_signup": false + } + }`, + }, { name: "POST /api/v1/admin/accounts/bulk-update", method: http.MethodPost, @@ -827,6 +1019,7 @@ func TestAPIContracts(t *testing.T) { type contractDeps struct { now time.Time router http.Handler + cfg *config.Config apiKeyRepo *stubApiKeyRepo groupRepo *stubGroupRepo userSubRepo *stubUserSubscriptionRepo @@ -947,6 +1140,7 @@ func newContractDeps(t *testing.T) *contractDeps { return &contractDeps{ now: now, router: r, + cfg: cfg, apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userSubRepo: userSubRepo, diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 059bbcd3..72569882 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -245,15 +245,107 @@ func parseWeChatConnectCapabilitySettings(settings map[string]string, enabled bo } func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string { + mode = normalizeWeChatConnectModeSetting(mode) + switch mode { + case "open": + if openEnabled { + return "open" + } + case "mp": + if mpEnabled { + return "mp" + } + case "mobile": + if mobileEnabled { + return "mobile" + } + } switch { + case openEnabled: + return "open" case mpEnabled: return "mp" case mobileEnabled: return "mobile" - case openEnabled: - return "open" default: - return normalizeWeChatConnectModeSetting(mode) + return mode + } +} + +func mergeWeChatConnectCapabilitySettings(settings map[string]string, base config.WeChatConnectConfig, enabled bool, mode string) (bool, bool, bool) { + mode = normalizeWeChatConnectModeSetting(firstNonEmpty(mode, base.Mode)) + rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled] + rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled] + rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled] + openConfigured := hasOpen && strings.TrimSpace(rawOpen) != "" + mpConfigured := hasMP && strings.TrimSpace(rawMP) != "" + mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" + + if openConfigured || mpConfigured || mobileConfigured { + return parseWeChatConnectCapabilitySettings(settings, enabled, mode) + } + if !enabled { + return false, false, false + } + if base.OpenEnabled || base.MPEnabled || base.MobileEnabled { + return base.OpenEnabled, base.MPEnabled, base.MobileEnabled + } + return parseWeChatConnectCapabilitySettings(settings, enabled, mode) +} + +func (s *SettingService) effectiveWeChatConnectOAuthConfig(settings map[string]string) WeChatConnectOAuthConfig { + base := config.WeChatConnectConfig{} + if s != nil && s.cfg != nil { + base = s.cfg.WeChat + } + + enabled := base.Enabled + if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok { + enabled = strings.TrimSpace(raw) == "true" + } + + legacyAppID := strings.TrimSpace(firstNonEmpty( + settings[SettingKeyWeChatConnectAppID], + base.AppID, + base.OpenAppID, + base.MPAppID, + base.MobileAppID, + )) + legacyAppSecret := strings.TrimSpace(firstNonEmpty( + settings[SettingKeyWeChatConnectAppSecret], + base.AppSecret, + base.OpenAppSecret, + base.MPAppSecret, + base.MobileAppSecret, + )) + openAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], base.OpenAppID, legacyAppID)) + openAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], base.OpenAppSecret, legacyAppSecret)) + mpAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], base.MPAppID, legacyAppID)) + mpAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], base.MPAppSecret, legacyAppSecret)) + mobileAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], base.MobileAppID, legacyAppID)) + mobileAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], base.MobileAppSecret, legacyAppSecret)) + + modeRaw := firstNonEmpty(settings[SettingKeyWeChatConnectMode], base.Mode) + openEnabled, mpEnabled, mobileEnabled := mergeWeChatConnectCapabilitySettings(settings, base, enabled, modeRaw) + mode := normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled, modeRaw) + + return WeChatConnectOAuthConfig{ + Enabled: enabled, + LegacyAppID: legacyAppID, + LegacyAppSecret: legacyAppSecret, + OpenAppID: openAppID, + OpenAppSecret: openAppSecret, + MPAppID: mpAppID, + MPAppSecret: mpAppSecret, + MobileAppID: mobileAppID, + MobileAppSecret: mobileAppSecret, + OpenEnabled: openEnabled, + MPEnabled: mpEnabled, + MobileEnabled: mobileEnabled, + Mode: mode, + Scopes: normalizeWeChatConnectScopeSetting(firstNonEmpty(settings[SettingKeyWeChatConnectScopes], base.Scopes), mode), + RedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectRedirectURL], base.RedirectURL)), + FrontendRedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectFrontendRedirectURL], base.FrontendRedirectURL, defaultWeChatConnectFrontend)), } } @@ -535,32 +627,7 @@ func DefaultWeChatConnectScopesForMode(mode string) string { } func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) { - enabled := settings[SettingKeyWeChatConnectEnabled] == "true" - mode := normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]) - openEnabled, mpEnabled, mobileEnabled := parseWeChatConnectCapabilitySettings(settings, enabled, mode) - mode = normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled, mode) - - cfg := WeChatConnectOAuthConfig{ - Enabled: enabled, - LegacyAppID: strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]), - LegacyAppSecret: strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]), - OpenAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], settings[SettingKeyWeChatConnectAppID])), - OpenAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], settings[SettingKeyWeChatConnectAppSecret])), - MPAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], settings[SettingKeyWeChatConnectAppID])), - MPAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], settings[SettingKeyWeChatConnectAppSecret])), - MobileAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], settings[SettingKeyWeChatConnectAppID])), - MobileAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], settings[SettingKeyWeChatConnectAppSecret])), - OpenEnabled: openEnabled, - MPEnabled: mpEnabled, - MobileEnabled: mobileEnabled, - Mode: mode, - Scopes: normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], mode), - RedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]), - FrontendRedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]), - } - if cfg.FrontendRedirectURL == "" { - cfg.FrontendRedirectURL = defaultWeChatConnectFrontend - } + cfg := s.effectiveWeChatConnectOAuthConfig(settings) if !cfg.Enabled || (!cfg.OpenEnabled && !cfg.MPEnabled) { return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") @@ -589,14 +656,10 @@ func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]strin return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app secret not configured") } } - if cfg.RedirectURL == "" { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") - } - if cfg.FrontendRedirectURL == "" { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url not configured") - } - if err := config.ValidateAbsoluteHTTPURL(cfg.RedirectURL); err != nil { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid") + if v := strings.TrimSpace(cfg.RedirectURL); v != "" { + if err := config.ValidateAbsoluteHTTPURL(v); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid") + } } if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil { return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid") @@ -605,31 +668,14 @@ func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]strin } func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool, bool) { - if settings[SettingKeyWeChatConnectEnabled] != "true" { + cfg := s.effectiveWeChatConnectOAuthConfig(settings) + if !cfg.Enabled { return false, false, false, false } - mode := normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]) - openEnabled, mpEnabled, mobileEnabled := parseWeChatConnectCapabilitySettings(settings, true, mode) - redirectURL := strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]) - frontendRedirectURL := strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]) - if frontendRedirectURL == "" { - frontendRedirectURL = defaultWeChatConnectFrontend - } - - legacyAppID := strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]) - legacyAppSecret := strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]) - openAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], legacyAppID)) - openAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], legacyAppSecret)) - mpAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], legacyAppID)) - mpAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], legacyAppSecret)) - mobileAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], legacyAppID)) - mobileAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], legacyAppSecret)) - - webRedirectReady := redirectURL != "" && frontendRedirectURL != "" - openReady := openEnabled && webRedirectReady && openAppID != "" && openAppSecret != "" - mpReady := mpEnabled && webRedirectReady && mpAppID != "" && mpAppSecret != "" - mobileReady := mobileEnabled && mobileAppID != "" && mobileAppSecret != "" + openReady := cfg.OpenEnabled && cfg.AppIDForMode("open") != "" && cfg.AppSecretForMode("open") != "" + mpReady := cfg.MPEnabled && cfg.AppIDForMode("mp") != "" && cfg.AppSecretForMode("mp") != "" + mobileReady := cfg.MobileEnabled && cfg.AppIDForMode("mobile") != "" && cfg.AppSecretForMode("mobile") != "" return openReady || mpReady, openReady, mpReady, mobileReady } @@ -1436,6 +1482,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyCustomMenuItems: "[]", SettingKeyCustomEndpoints: "[]", SettingKeyWeChatConnectEnabled: "false", + SettingKeyWeChatConnectAppID: "", + SettingKeyWeChatConnectAppSecret: "", SettingKeyWeChatConnectOpenAppID: "", SettingKeyWeChatConnectOpenAppSecret: "", SettingKeyWeChatConnectMPAppID: "", @@ -1447,9 +1495,30 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyWeChatConnectMobileEnabled: "false", SettingKeyWeChatConnectMode: "open", SettingKeyWeChatConnectScopes: "snsapi_login", + SettingKeyWeChatConnectRedirectURL: "", SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend, SettingKeyOIDCConnectEnabled: "false", SettingKeyOIDCConnectProviderName: "OIDC", + SettingKeyOIDCConnectClientID: "", + SettingKeyOIDCConnectClientSecret: "", + SettingKeyOIDCConnectIssuerURL: "", + SettingKeyOIDCConnectDiscoveryURL: "", + SettingKeyOIDCConnectAuthorizeURL: "", + SettingKeyOIDCConnectTokenURL: "", + SettingKeyOIDCConnectUserInfoURL: "", + SettingKeyOIDCConnectJWKSURL: "", + SettingKeyOIDCConnectScopes: "openid email profile", + SettingKeyOIDCConnectRedirectURL: "", + SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + SettingKeyOIDCConnectUsePKCE: "true", + SettingKeyOIDCConnectValidateIDToken: "true", + SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", + SettingKeyOIDCConnectClockSkewSeconds: "120", + SettingKeyOIDCConnectRequireEmailVerified: "false", + SettingKeyOIDCConnectUserInfoEmailPath: "", + SettingKeyOIDCConnectUserInfoIDPath: "", + SettingKeyOIDCConnectUserInfoUsernamePath: "", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultSubscriptions: "[]", @@ -1737,37 +1806,30 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != "" - // WeChat Connect 设置:完全以 DB 系统设置为准。 - result.WeChatConnectEnabled = settings[SettingKeyWeChatConnectEnabled] == "true" - result.WeChatConnectAppID = strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]) - result.WeChatConnectAppSecret = strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]) - result.WeChatConnectAppSecretConfigured = result.WeChatConnectAppSecret != "" - result.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], result.WeChatConnectAppID)) - result.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], result.WeChatConnectAppSecret)) - result.WeChatConnectOpenAppSecretConfigured = result.WeChatConnectOpenAppSecret != "" - result.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], result.WeChatConnectAppID)) - result.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], result.WeChatConnectAppSecret)) - result.WeChatConnectMPAppSecretConfigured = result.WeChatConnectMPAppSecret != "" - result.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], result.WeChatConnectAppID)) - result.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], result.WeChatConnectAppSecret)) - result.WeChatConnectMobileAppSecretConfigured = result.WeChatConnectMobileAppSecret != "" - result.WeChatConnectOpenEnabled, result.WeChatConnectMPEnabled, result.WeChatConnectMobileEnabled = parseWeChatConnectCapabilitySettings( - settings, - result.WeChatConnectEnabled, - settings[SettingKeyWeChatConnectMode], - ) - result.WeChatConnectMode = normalizeWeChatConnectStoredMode( - result.WeChatConnectOpenEnabled, - result.WeChatConnectMPEnabled, - result.WeChatConnectMobileEnabled, - settings[SettingKeyWeChatConnectMode], - ) - result.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], result.WeChatConnectMode) - result.WeChatConnectRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]) - result.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]) - if result.WeChatConnectFrontendRedirectURL == "" { - result.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend - } + // WeChat Connect 设置: + // - 优先读取 DB 系统设置 + // - 缺失时回退到 config/env,保持升级兼容 + weChatEffective := s.effectiveWeChatConnectOAuthConfig(settings) + result.WeChatConnectEnabled = weChatEffective.Enabled + result.WeChatConnectAppID = weChatEffective.LegacyAppID + result.WeChatConnectAppSecret = weChatEffective.LegacyAppSecret + result.WeChatConnectAppSecretConfigured = weChatEffective.LegacyAppSecret != "" + result.WeChatConnectOpenAppID = weChatEffective.OpenAppID + result.WeChatConnectOpenAppSecret = weChatEffective.OpenAppSecret + result.WeChatConnectOpenAppSecretConfigured = weChatEffective.OpenAppSecret != "" + result.WeChatConnectMPAppID = weChatEffective.MPAppID + result.WeChatConnectMPAppSecret = weChatEffective.MPAppSecret + result.WeChatConnectMPAppSecretConfigured = weChatEffective.MPAppSecret != "" + result.WeChatConnectMobileAppID = weChatEffective.MobileAppID + result.WeChatConnectMobileAppSecret = weChatEffective.MobileAppSecret + result.WeChatConnectMobileAppSecretConfigured = weChatEffective.MobileAppSecret != "" + result.WeChatConnectOpenEnabled = weChatEffective.OpenEnabled + result.WeChatConnectMPEnabled = weChatEffective.MPEnabled + result.WeChatConnectMobileEnabled = weChatEffective.MobileEnabled + result.WeChatConnectMode = weChatEffective.Mode + result.WeChatConnectScopes = weChatEffective.Scopes + result.WeChatConnectRedirectURL = weChatEffective.RedirectURL + result.WeChatConnectFrontendRedirectURL = weChatEffective.FrontendRedirectURL // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go index a5a3959a..eb312d2c 100644 --- a/backend/internal/service/setting_service_oidc_config_test.go +++ b/backend/internal/service/setting_service_oidc_config_test.go @@ -115,6 +115,22 @@ func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t require.False(t, got.OIDCConnectValidateIDToken) } +func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) { + svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ + OIDC: config.OIDCConnectConfig{ + UsePKCE: true, + ValidateIDToken: true, + }, + }) + + got := svc.parseSettings(map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }) + + require.True(t, got.OIDCConnectUsePKCE) + require.True(t, got.OIDCConnectValidateIDToken) +} + func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { cfg := &config.Config{ OIDC: config.OIDCConnectConfig{ @@ -145,3 +161,37 @@ func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTok require.False(t, got.UsePKCE) require.False(t, got.ValidateIDToken) } + +func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) { + cfg := &config.Config{ + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + JWKSURL: "https://issuer.example.com/jwks", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + } + + repo := &settingOIDCRepoStub{values: map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }} + svc := NewSettingService(repo, cfg) + + got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.UsePKCE) + require.True(t, got.ValidateIDToken) +} diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 4c7ca14b..1ecd4e6f 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -132,3 +132,22 @@ func TestSettingService_GetPublicSettings_DoesNotExposeMobileOnlyWeChatAsWebOAut require.False(t, settings.WeChatOAuthMPEnabled) require.True(t, settings.WeChatOAuthMobileEnabled) } + +func TestSettingService_GetPublicSettings_FallsBackToConfigForWeChatOAuthCapabilities(t *testing.T) { + svc := NewSettingService(&settingPublicRepoStub{values: map[string]string{}}, &config.Config{ + WeChat: config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + FrontendRedirectURL: "/auth/wechat/config-callback", + }, + }) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.WeChatOAuthEnabled) + require.True(t, settings.WeChatOAuthOpenEnabled) + require.False(t, settings.WeChatOAuthMPEnabled) + require.False(t, settings.WeChatOAuthMobileEnabled) +} diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go index 73d86e8f..08f67b7c 100644 --- a/backend/internal/service/setting_service_wechat_config_test.go +++ b/backend/internal/service/setting_service_wechat_config_test.go @@ -79,3 +79,54 @@ func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *tes require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL) require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL) } + +func TestSettingService_GetWeChatConnectOAuthConfig_FallsBackToConfigWhenDatabaseEmpty(t *testing.T) { + repo := &settingWeChatRepoStub{values: map[string]string{}} + svc := NewSettingService(repo, &config.Config{ + WeChat: config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + MPEnabled: true, + Mode: "open", + OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + MPAppID: "wx-mp-config", + MPAppSecret: "wx-mp-secret", + FrontendRedirectURL: "/auth/wechat/config-callback", + }, + }) + + got, err := svc.GetWeChatConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.Enabled) + require.True(t, got.OpenEnabled) + require.True(t, got.MPEnabled) + require.Equal(t, "wx-open-config", got.AppIDForMode("open")) + require.Equal(t, "wx-open-secret", got.AppSecretForMode("open")) + require.Equal(t, "wx-mp-config", got.AppIDForMode("mp")) + require.Equal(t, "wx-mp-secret", got.AppSecretForMode("mp")) + require.Equal(t, "/auth/wechat/config-callback", got.FrontendRedirectURL) + require.Empty(t, got.RedirectURL) +} + +func TestSettingService_ParseSettings_FallsBackToConfigForWeChatAdminView(t *testing.T) { + svc := NewSettingService(&settingWeChatRepoStub{values: map[string]string{}}, &config.Config{ + WeChat: config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + Mode: "open", + OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + FrontendRedirectURL: "/auth/wechat/config-callback", + }, + }) + + got := svc.parseSettings(map[string]string{}) + require.True(t, got.WeChatConnectEnabled) + require.True(t, got.WeChatConnectOpenEnabled) + require.Equal(t, "wx-open-config", got.WeChatConnectOpenAppID) + require.True(t, got.WeChatConnectOpenAppSecretConfigured) + require.Equal(t, "/auth/wechat/config-callback", got.WeChatConnectFrontendRedirectURL) + require.Equal(t, "open", got.WeChatConnectMode) + require.Equal(t, "snsapi_login", got.WeChatConnectScopes) +} From d6a04bb772bee52a501c8d69cbf6a82a87afb1f4 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 12:30:17 +0800 Subject: [PATCH 12/31] fix(payment): support source routing and compatible resume signing --- .../handler/payment_handler_resume_test.go | 120 +++++++++++++- .../internal/service/payment_config_limits.go | 14 +- .../service/payment_config_limits_test.go | 149 +++++++++++------- .../service/payment_config_providers_test.go | 41 +++-- .../service/payment_order_lifecycle.go | 29 ++++ .../service/payment_order_result_test.go | 41 +++++ .../internal/service/payment_resume_lookup.go | 20 ++- .../service/payment_resume_lookup_test.go | 13 +- .../service/payment_resume_service.go | 47 +++++- .../service/payment_resume_service_test.go | 71 +++++++-- backend/internal/service/payment_service.go | 48 +++++- .../payment_visible_method_instances.go | 113 +++++++------ 12 files changed, 570 insertions(+), 136 deletions(-) diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go index 5a2ecb46..a7bc4ba3 100644 --- a/backend/internal/handler/payment_handler_resume_test.go +++ b/backend/internal/handler/payment_handler_resume_test.go @@ -164,9 +164,8 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) { } func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) { - t.Parallel() - gin.SetMode(gin.TestMode) + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared") require.NoError(t, err) @@ -250,3 +249,120 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing require.Contains(t, resp.Data, "expires_at") require.Contains(t, resp.Data, "refund_amount") } + +func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") + + db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + user, err := client.User.Create(). + SetEmail("public-resolve-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("public-resolve-mismatch-user"). + Save(context.Background()) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(100). + SetPayAmount(103). + SetFeeRate(0.03). + SetRechargeCode("PUBLIC-RESOLVE-MISMATCH"). + SetOutTradeNo("resolve-order-mismatch-no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-resolve-mismatch"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(service.OrderStatusPaid). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(context.Background()) + require.NoError(t, err) + + resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID + 999, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/resolve", + bytes.NewBufferString(`{"resume_token":"`+token+`"}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.ResolveOrderPublicByResumeToken(ctx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Reason string `json:"reason"` + Message string `json:"message"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason) +} + +func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) { + gin.SetMode(gin.TestMode) + + db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/verify", + bytes.NewBufferString(`{"out_trade_no":" "}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.VerifyOrderPublic(ctx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Reason string `json:"reason"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason) +} diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go index 57a4108f..e44bf2e7 100644 --- a/backend/internal/service/payment_config_limits.go +++ b/backend/internal/service/payment_config_limits.go @@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M return nil, fmt.Errorf("query provider instances: %w", err) } typeInstances := pcGroupByPaymentType(instances) - typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances) + typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances) resp := &MethodLimitsResponse{ Methods: make(map[string]MethodLimits, len(typeInstances)), } @@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M return resp, nil } -func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance { +func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance { if len(typeInstances) == 0 { return typeInstances } @@ -44,11 +44,17 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { matching := filterEnabledVisibleMethodInstances(instances, method) - if len(matching) != 1 { + providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching) + if err != nil || providerKey == "" { delete(filtered, method) continue } - filtered[method] = []*dbent.PaymentProviderInstance{matching[0]} + selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey) + if len(selectedInstances) == 0 { + delete(filtered, method) + continue + } + filtered[method] = selectedInstances } return filtered } diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go index b3925583..12cd6866 100644 --- a/backend/internal/service/payment_config_limits_test.go +++ b/backend/internal/service/payment_config_limits_test.go @@ -301,65 +301,104 @@ func TestPcInstanceTypeLimits(t *testing.T) { }) } -func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) { - ctx := context.Background() - client := newPaymentConfigServiceTestClient(t) - - _, err := client.PaymentProviderInstance.Create(). - SetProviderKey(payment.TypeAlipay). - SetName("Official Alipay"). - SetConfig("{}"). - SetSupportedTypes("alipay"). - SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). - SetEnabled(true). - Save(ctx) - if err != nil { - t.Fatalf("create official alipay instance: %v", err) - } - _, err = client.PaymentProviderInstance.Create(). - SetProviderKey(payment.TypeEasyPay). - SetName("EasyPay Alipay"). - SetConfig("{}"). - SetSupportedTypes("alipay"). - SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`). - SetEnabled(true). - Save(ctx) - if err != nil { - t.Fatalf("create easypay alipay instance: %v", err) - } - _, err = client.PaymentProviderInstance.Create(). - SetProviderKey(payment.TypeWxpay). - SetName("Official WeChat"). - SetConfig("{}"). - SetSupportedTypes("wxpay"). - SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`). - SetEnabled(true). - Save(ctx) - if err != nil { - t.Fatalf("create official wxpay instance: %v", err) +func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) { + tests := []struct { + name string + sourceSetting string + wantAlipaySingleMin float64 + wantAlipaySingleMax float64 + wantGlobalMin float64 + wantGlobalMax float64 + }{ + { + name: "official source", + sourceSetting: VisibleMethodSourceOfficialAlipay, + wantAlipaySingleMin: 10, + wantAlipaySingleMax: 100, + wantGlobalMin: 10, + wantGlobalMax: 300, + }, + { + name: "easypay source", + sourceSetting: VisibleMethodSourceEasyPayAlipay, + wantAlipaySingleMin: 20, + wantAlipaySingleMax: 200, + wantGlobalMin: 20, + wantGlobalMax: 300, + }, } - svc := &PaymentConfigService{ - entClient: client, - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) - resp, err := svc.GetAvailableMethodLimits(ctx) - if err != nil { - t.Fatalf("GetAvailableMethodLimits returned error: %v", err) - } + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create official alipay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay alipay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create official wxpay instance: %v", err) + } - if _, ok := resp.Methods[payment.TypeAlipay]; ok { - t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay]) - } + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting, + }, + }, + } - wxpayLimits, ok := resp.Methods[payment.TypeWxpay] - if !ok { - t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods) - } - if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 { - t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits) - } - if resp.GlobalMin != 30 || resp.GlobalMax != 300 { - t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax) + resp, err := svc.GetAvailableMethodLimits(ctx) + if err != nil { + t.Fatalf("GetAvailableMethodLimits returned error: %v", err) + } + + alipayLimits, ok := resp.Methods[payment.TypeAlipay] + if !ok { + t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods) + } + if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax { + t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax) + } + + wxpayLimits, ok := resp.Methods[payment.TypeWxpay] + if !ok { + t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods) + } + if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 { + t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits) + } + if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax { + t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax) + } + }) } } diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index 2c0f8206..51d5c7b6 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -4,9 +4,12 @@ package service import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "testing" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -199,7 +202,7 @@ func TestJoinTypes(t *testing.T) { } } -func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) { +func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) { t.Parallel() ctx := context.Background() @@ -227,15 +230,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test _, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ ProviderKey: "alipay", Name: "Official Alipay", - Config: map[string]string{"appId": "app-1"}, + Config: map[string]string{"appId": "app-1", "privateKey": "private-key"}, SupportedTypes: []string{"alipay"}, Enabled: true, }) - require.Error(t, err) - require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err)) + require.NoError(t, err) } -func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) { +func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) { t.Parallel() ctx := context.Background() @@ -264,7 +266,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ ProviderKey: "wxpay", Name: "Official WeChat", - Config: map[string]string{"appId": "wx-app"}, + Config: validWxpayProviderConfig(t), SupportedTypes: []string{"wxpay"}, Enabled: false, }) @@ -273,8 +275,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t _, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{ Enabled: boolPtrValue(true), }) - require.Error(t, err) - require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err)) + require.NoError(t, err) } func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { @@ -317,3 +318,25 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { func boolPtrValue(v bool) *bool { return &v } + +func validWxpayProviderConfig(t *testing.T) map[string]string { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + privDER, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + require.NoError(t, err) + + return map[string]string{ + "appId": "wx-app-test", + "mchId": "mch-test", + "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})), + "apiV3Key": "12345678901234567890123456789012", + "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})), + "publicKeyId": "public-key-id-test", + "certSerial": "cert-serial-test", + } +} diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index ffb63066..f14dc55d 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -234,6 +234,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current // if a payment was made, and processes it if so. This handles the case where // the provider's notify callback was missed (e.g. EasyPay popup mode). func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { + outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo) + if err != nil { + return nil, err + } o, err := s.entClient.PaymentOrder.Query(). Where(paymentorder.OutTradeNo(outTradeNo)). Only(ctx) @@ -261,6 +265,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo // triggering any upstream reconciliation. Signed resume-token recovery is the // only public recovery path allowed to query upstream state. func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { + outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo) + if err != nil { + return nil, err + } o, err := s.entClient.PaymentOrder.Query(). Where(paymentorder.OutTradeNo(outTradeNo)). Only(ctx) @@ -270,6 +278,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin return o, nil } +func normalizeOrderLookupOutTradeNo(raw string) (string, error) { + outTradeNo := strings.TrimSpace(raw) + if outTradeNo == "" { + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required") + } + if len(outTradeNo) > 64 { + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid") + } + for _, ch := range outTradeNo { + switch { + case ch >= 'a' && ch <= 'z': + case ch >= 'A' && ch <= 'Z': + case ch >= '0' && ch <= '9': + case ch == '_' || ch == '-': + default: + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid") + } + } + return outTradeNo, nil +} + func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { now := time.Now() orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go index 23371cfd..2d7412e0 100644 --- a/backend/internal/service/payment_order_result_test.go +++ b/backend/internal/service/payment_order_result_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "strings" "testing" "time" @@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) { } func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") + svc := newWeChatPaymentOAuthTestService(map[string]string{ SettingKeyWeChatConnectEnabled: "true", SettingKeyWeChatConnectAppID: "wx123456", @@ -198,6 +201,44 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testin } } +func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) { + svc := &PaymentService{ + configService: &PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }}, + // Legacy stable signing key remains available for no-config upgrade compatibility. + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + }, + } + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if resp == nil { + t.Fatal("expected oauth-required response, got nil") + } + if resp.ResultType != payment.CreatePaymentResultOAuthRequired { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired) + } + if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" { + t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth) + } +} + func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { svc := newWeChatPaymentOAuthTestService(map[string]string{ SettingKeyWeChatConnectEnabled: "true", diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go index 05626aa6..1ff061e8 100644 --- a/backend/internal/service/payment_resume_lookup.go +++ b/backend/internal/service/payment_resume_lookup.go @@ -6,6 +6,7 @@ import ( "strings" dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) { @@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID) if err != nil { + if dbent.IsNotFound(err) { + return nil, infraerrors.NotFound("NOT_FOUND", "order not found") + } return nil, fmt.Errorf("get order by resume token: %w", err) } if claims.UserID > 0 && order.UserID != claims.UserID { - return nil, fmt.Errorf("resume token user mismatch") + return nil, invalidResumeTokenMatchError() } snapshot := psOrderProviderSnapshot(order) orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) @@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token } } if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID { - return nil, fmt.Errorf("resume token provider instance mismatch") + return nil, invalidResumeTokenMatchError() } - if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey { - return nil, fmt.Errorf("resume token provider key mismatch") + if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) { + return nil, invalidResumeTokenMatchError() } - if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType { - return nil, fmt.Errorf("resume token payment type mismatch") + if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) { + return nil, invalidResumeTokenMatchError() } if order.Status == OrderStatusPending || order.Status == OrderStatusExpired { result := s.checkPaid(ctx, order) @@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token return order, nil } +func invalidResumeTokenMatchError() error { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order") +} + func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token)) } diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go index 946e7aa2..a7b5b737 100644 --- a/backend/internal/service/payment_resume_lookup_test.go +++ b/backend/internal/service/payment_resume_lookup_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/require" ) @@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) { _, err = svc.GetPublicOrderByResumeToken(ctx, token) require.Error(t, err) - require.Contains(t, err.Error(), "resume token") + require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err)) } func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) { @@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) { require.Equal(t, order.ID, got.ID) require.Equal(t, 0, provider.queryCount) } + +func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) { + svc := &PaymentService{ + entClient: newPaymentConfigServiceTestClient(t), + } + + _, err := svc.VerifyOrderPublic(context.Background(), " ") + require.Error(t, err) + require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err)) +} diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go index 438aa59f..9ae62fde 100644 --- a/backend/internal/service/payment_resume_service.go +++ b/backend/internal/service/payment_resume_service.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "context" "crypto/hmac" "crypto/sha256" @@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct { type PaymentResumeService struct { signingKey []byte + verifyKeys [][]byte } type visibleMethodLoadBalancer struct { @@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct { configService *PaymentConfigService } -func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { - return &PaymentResumeService{signingKey: signingKey} +func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService { + svc := &PaymentResumeService{} + if len(signingKey) > 0 { + svc.signingKey = append([]byte(nil), signingKey...) + svc.verifyKeys = append(svc.verifyKeys, svc.signingKey) + } + for _, fallback := range verifyFallbacks { + if len(fallback) == 0 { + continue + } + cloned := append([]byte(nil), fallback...) + duplicate := false + for _, existing := range svc.verifyKeys { + if bytes.Equal(existing, cloned) { + duplicate = true + break + } + } + if !duplicate { + svc.verifyKeys = append(svc.verifyKeys, cloned) + } + } + return svc } func (s *PaymentResumeService) isSigningConfigured() bool { @@ -410,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error { if len(parts) != 2 || parts[0] == "" || parts[1] == "" { return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") } - if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) { + if !s.verifySignature(parts[0], parts[1]) { return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") } payload, err := base64.RawURLEncoding.DecodeString(parts[0]) @@ -420,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error { return json.Unmarshal(payload, dest) } +func (s *PaymentResumeService) verifySignature(payload string, signature string) bool { + if s == nil { + return false + } + for _, key := range s.verifyKeys { + if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) { + return true + } + } + return false +} + func validatePaymentResumeExpiry(expiresAt int64, code, message string) error { if expiresAt <= 0 { return nil @@ -431,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error { } func (s *PaymentResumeService) sign(payload string) string { - mac := hmac.New(sha256.New, s.signingKey) + return signPaymentResumePayload(payload, s.signingKey) +} + +func signPaymentResumePayload(payload string, key []byte) string { + mac := hmac.New(sha256.New, key) _, _ = mac.Write([]byte(payload)) return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) } diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index e19e0b99..9e756971 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -334,6 +334,59 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) { } } +func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key") + + token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-explicit-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{ + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + }, + } + + claims, err := svc.ParseWeChatPaymentResumeToken(token) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if claims.OpenID != "openid-explicit-key" { + t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key") + } +} + +func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key") + + legacyKey := []byte("0123456789abcdef0123456789abcdef") + token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-legacy-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{ + encryptionKey: legacyKey, + }, + } + + claims, err := svc.ParseWeChatPaymentResumeToken(token) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if claims.OpenID != "openid-legacy-key" { + t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key") + } +} + func TestNormalizeVisibleMethodSource(t *testing.T) { t.Parallel() @@ -424,14 +477,14 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl t.Parallel() tests := []struct { - name string - method payment.PaymentType - officialName string - officialTypes string - easyPayName string - easyPayTypes string - sourceSetting string - wantProvider string + name string + method payment.PaymentType + officialName string + officialTypes string + easyPayName string + easyPayTypes string + sourceSetting string + wantProvider string }{ { name: "alipay uses official source", @@ -487,7 +540,7 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl officialProviderKey = payment.TypeWxpay } - _, err = client.PaymentProviderInstance.Create(). + _, err := client.PaymentProviderInstance.Create(). SetProviderKey(officialProviderKey). SetName(tt.officialName). SetConfig("{}"). diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 73bbb256..159f97d3 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -1,10 +1,14 @@ package service import ( + "bytes" "context" + "encoding/hex" "fmt" "log/slog" "math/rand/v2" + "os" + "strings" "sync" "time" @@ -44,6 +48,8 @@ const ( orderIDPrefix = "sub2_" ) +const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY" + // --- Types --- // generateOutTradeNo creates a unique external order ID for payment providers. @@ -179,7 +185,7 @@ type PaymentService struct { func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} - svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService)) + svc.resumeService = psNewPaymentResumeService(configService) return svc } @@ -259,16 +265,54 @@ func (s *PaymentService) paymentResume() *PaymentResumeService { if s.resumeService != nil { return s.resumeService } - return NewPaymentResumeService(psResumeSigningKey(s.configService)) + return psNewPaymentResumeService(s.configService) +} + +func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService { + signingKey, verifyFallbacks := psResumeSigningKeys(configService) + return NewPaymentResumeService(signingKey, verifyFallbacks...) } func psResumeSigningKey(configService *PaymentConfigService) []byte { + signingKey, _ := psResumeSigningKeys(configService) + return signingKey +} + +func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) { + signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv)) + legacyKey := psResumeLegacyVerificationKey(configService) + if len(signingKey) == 0 { + if len(legacyKey) == 0 { + return nil, nil + } + return legacyKey, nil + } + if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) { + return signingKey, nil + } + return signingKey, [][]byte{legacyKey} +} + +func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte { if configService == nil { return nil } return configService.encryptionKey } +func parsePaymentResumeSigningKey(raw string) []byte { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if len(raw) >= 64 && len(raw)%2 == 0 { + if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 { + return decoded + } + } + return []byte(raw) +} + func psSliceContains(sl []string, s string) bool { for _, v := range sl { if v == s { diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go index 39358522..86ea5ead 100644 --- a/backend/internal/service/payment_visible_method_instances.go +++ b/backend/internal/service/payment_visible_method_instances.go @@ -82,6 +82,41 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta return filtered } +func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance { + filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances)) + for _, inst := range instances { + if !providerSupportsVisibleMethod(inst, method) { + continue + } + if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) { + continue + } + filtered = append(filtered, inst) + } + return filtered +} + +func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string { + seen := make(map[string]struct{}, len(instances)) + keys := make([]string, 0, len(instances)) + for _, inst := range instances { + if inst == nil { + continue + } + key := strings.TrimSpace(inst.ProviderKey) + if key == "" { + continue + } + normalized := strings.ToLower(key) + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + keys = append(keys, key) + } + return keys +} + func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance { providerKey = strings.TrimSpace(providerKey) if providerKey == "" { @@ -117,32 +152,10 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts( supportedTypes string, enabled bool, ) error { - if s == nil || s.entClient == nil || !enabled { - return nil - } - - claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes) - if len(claimedMethods) == 0 { - return nil - } - - query := s.entClient.PaymentProviderInstance.Query(). - Where(paymentproviderinstance.EnabledEQ(true)) - if excludeID > 0 { - query = query.Where(paymentproviderinstance.IDNEQ(excludeID)) - } - instances, err := query.All(ctx) - if err != nil { - return fmt.Errorf("query enabled payment providers: %w", err) - } - - for _, method := range claimedMethods { - for _, inst := range instances { - if providerSupportsVisibleMethod(inst, method) { - return buildPaymentProviderConflictError(method, inst) - } - } - } + // Visible methods are selected by configured source (official/easypay), + // so multiple enabled providers can intentionally claim the same user-facing + // method. Order creation and limits will route through the configured source. + _, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled return nil } @@ -172,6 +185,32 @@ func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context return providerKey, nil } +func (s *PaymentConfigService) resolveVisibleMethodProviderKey( + ctx context.Context, + method string, + matching []*dbent.PaymentProviderInstance, +) (string, error) { + switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) { + case 0: + return "", nil + case 1: + return strings.TrimSpace(providerKeys[0]), nil + default: + providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method) + if err != nil { + return "", err + } + selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey) + if selected == nil { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source has no enabled provider instance", method), + ) + } + return strings.TrimSpace(selected.ProviderKey), nil + } +} + func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( ctx context.Context, method string, @@ -194,23 +233,9 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( } matching := filterEnabledVisibleMethodInstances(instances, method) - switch len(matching) { - case 0: - return nil, nil - case 1: - return matching[0], nil - default: - providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method) - if err != nil { - return nil, err - } - selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey) - if selected == nil { - return nil, infraerrors.BadRequest( - "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", - fmt.Sprintf("%s source has no enabled provider instance", method), - ) - } - return selected, nil + providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching) + if err != nil { + return nil, err } + return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil } From 29caf85104bea864eaf1d2d33ba5759b8c835ff1 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 12:30:24 +0800 Subject: [PATCH 13/31] fix(frontend): stabilize wechat payment resume recovery --- frontend/src/views/user/PaymentResultView.vue | 88 +++++++++---- frontend/src/views/user/PaymentView.vue | 44 ++++++- .../user/__tests__/PaymentResultView.spec.ts | 61 +++++++++ .../views/user/__tests__/PaymentView.spec.ts | 121 ++++++++++++++++++ .../__tests__/paymentWechatResume.spec.ts | 3 +- .../src/views/user/paymentWechatResume.ts | 17 ++- 6 files changed, 304 insertions(+), 30 deletions(-) diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index cbebaa83..b75d75df 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -181,6 +181,54 @@ function isPendingStatus(status: string | null | undefined): boolean { return PENDING_STATUSES.has(normalizeOrderStatus(status)) } +function readRouteQueryString(key: string): string { + const value = route.query[key] + if (Array.isArray(value)) { + return typeof value[0] === 'string' ? value[0] : '' + } + return typeof value === 'string' ? value : '' +} + +function restoreRecoverySnapshot(context: { + resumeToken: string + routeOrderId: number + routeOutTradeNo: string +}) { + if (typeof window === 'undefined') { + return null + } + + const rawSnapshot = window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY) + if (!rawSnapshot) { + return null + } + + if (context.resumeToken) { + return readPaymentRecoverySnapshot(rawSnapshot, { + resumeToken: context.resumeToken, + }) + } + + if (!context.routeOrderId && !context.routeOutTradeNo) { + return null + } + + const restored = readPaymentRecoverySnapshot(rawSnapshot) + if (!restored) { + return null + } + + if (context.routeOrderId > 0 && restored.orderId !== context.routeOrderId) { + return null + } + + if (context.routeOutTradeNo && restored.outTradeNo !== context.routeOutTradeNo) { + return null + } + + return restored +} + async function resolveOrderFromResumeToken(resumeToken: string): Promise { try { const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken) @@ -239,24 +287,21 @@ function scheduleStatusRefresh(refreshOrder: (() => Promise } onMounted(async () => { - const resumeToken = typeof route.query.resume_token === 'string' - ? route.query.resume_token - : '' - const routeOrderId = Number(route.query.order_id) || 0 - let outTradeNo = String(route.query.out_trade_no || '') + const resumeToken = readRouteQueryString('resume_token') + const routeOrderId = Number(readRouteQueryString('order_id')) || 0 + let outTradeNo = readRouteQueryString('out_trade_no') let orderId = 0 - if (typeof window !== 'undefined') { - const restored = readPaymentRecoverySnapshot( - window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY), - resumeToken ? { resumeToken } : {}, - ) - if (restored?.orderId) { - orderId = restored.orderId - } - if (!outTradeNo && restored?.outTradeNo) { - outTradeNo = restored.outTradeNo - } + const restored = restoreRecoverySnapshot({ + resumeToken, + routeOrderId, + routeOutTradeNo: outTradeNo, + }) + if (restored?.orderId) { + orderId = restored.orderId + } + if (!outTradeNo && restored?.outTradeNo) { + outTradeNo = restored.outTradeNo } if (resumeToken) { @@ -266,15 +311,14 @@ onMounted(async () => { if (!orderId) { orderId = resolvedOrder.id } + } else if (routeOrderId > 0) { + orderId = routeOrderId } - } - - if (!resumeToken) { + } else if (routeOrderId > 0) { orderId = routeOrderId } - const hasLegacyFallbackContext = typeof route.query.trade_status === 'string' - && route.query.trade_status.trim() !== '' + const hasLegacyFallbackContext = readRouteQueryString('trade_status').trim() !== '' const shouldUsePublicOutTradeNo = !resumeToken && outTradeNo !== '' && (hasLegacyFallbackContext || routeOrderId > 0 || orderId > 0) if (!order.value && shouldUsePublicOutTradeNo) { @@ -287,7 +331,7 @@ onMounted(async () => { } } - if (!order.value && !resumeToken && orderId) { + if (!order.value && orderId && (!resumeToken || routeOrderId > 0)) { try { order.value = await paymentStore.pollOrderStatus(orderId) } catch (_err: unknown) { diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 1577039e..05d70512 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -409,6 +409,43 @@ async function redirectToPaymentResult(state: PaymentRecoverySnapshot): Promise< }) } +function buildWechatOAuthAuthorizeUrl( + authorizeUrl: string, + context: { paymentType: string; orderType: OrderType; planId?: number; orderAmount: number }, +): string { + const normalizedUrl = authorizeUrl.trim() + if (!normalizedUrl || typeof window === 'undefined') { + return normalizedUrl + } + + try { + const targetUrl = new URL(normalizedUrl, window.location.origin) + const redirectPath = targetUrl.searchParams.get('redirect') || '/purchase' + const redirectUrl = new URL(redirectPath, window.location.origin) + const paymentType = normalizeVisibleMethod(context.paymentType) || context.paymentType.trim() || 'wxpay' + + redirectUrl.searchParams.set('payment_type', paymentType) + redirectUrl.searchParams.set('order_type', context.orderType) + + if (context.planId) { + redirectUrl.searchParams.set('plan_id', String(context.planId)) + } else { + redirectUrl.searchParams.delete('plan_id') + } + + if (context.orderAmount > 0) { + redirectUrl.searchParams.set('amount', String(context.orderAmount)) + } else { + redirectUrl.searchParams.delete('amount') + } + + targetUrl.searchParams.set('redirect', `${redirectUrl.pathname}${redirectUrl.search}`) + return targetUrl.toString() + } catch { + return normalizedUrl + } +} + function onPaymentDone() { const wasSubscription = paymentState.value.orderType === 'subscription' resetPayment() @@ -676,7 +713,12 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n }) if (decision.kind === 'wechat_oauth' && decision.oauth?.authorize_url) { - window.location.href = decision.oauth.authorize_url + window.location.href = buildWechatOAuthAuthorizeUrl(decision.oauth.authorize_url, { + paymentType: visibleMethod, + orderType, + planId, + orderAmount, + }) return } diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts index 91741963..81a7ccf0 100644 --- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -220,6 +220,41 @@ describe('PaymentResultView', () => { expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() }) + it('falls back to order_id polling when resume-token recovery fails', async () => { + routeState.query = { + resume_token: 'resume-fail', + order_id: '77', + } + window.localStorage.setItem( + PAYMENT_RECOVERY_STORAGE_KEY, + JSON.stringify({ + ...recoverySnapshotFactory('resume-fail'), + orderId: 42, + }), + ) + resolveOrderPublicByResumeToken.mockRejectedValueOnce(new Error('resume failed')) + pollOrderStatus.mockResolvedValueOnce({ + ...orderFactory('PAID'), + id: 77, + }) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-fail') + expect(pollOrderStatus).toHaveBeenCalledWith(77) + expect(verifyOrderPublic).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('payment.result.success') + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() + }) + it('does not fall back to public out_trade_no verification when resume_token recovery fails', async () => { routeState.query = { resume_token: 'resume-fail', @@ -241,6 +276,32 @@ describe('PaymentResultView', () => { expect(verifyOrderPublic).not.toHaveBeenCalled() }) + it('ignores a stale global recovery snapshot when legacy return markers do not identify the order', async () => { + routeState.query = { + trade_status: 'TRADE_SUCCESS', + } + window.localStorage.setItem( + PAYMENT_RECOVERY_STORAGE_KEY, + JSON.stringify(recoverySnapshotFactory('resume-stale')), + ) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(resolveOrderPublicByResumeToken).not.toHaveBeenCalled() + expect(verifyOrderPublic).not.toHaveBeenCalled() + expect(pollOrderStatus).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('payment.result.failed') + expect(wrapper.text()).not.toContain('sub2_20260420abcd1234') + }) + it('uses public out_trade_no verification when no signed resume context is available', async () => { routeState.query = { out_trade_no: 'legacy-123', diff --git a/frontend/src/views/user/__tests__/PaymentView.spec.ts b/frontend/src/views/user/__tests__/PaymentView.spec.ts index 66648da4..2b81a085 100644 --- a/frontend/src/views/user/__tests__/PaymentView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentView.spec.ts @@ -109,6 +109,35 @@ function checkoutInfoFixture() { } } +function checkoutInfoWithPlansFixture() { + return { + data: { + ...checkoutInfoFixture().data, + plans: [ + { + id: 7, + group_id: 3, + name: 'Starter', + description: '', + price: 128, + original_price: 0, + validity_days: 30, + validity_unit: 'day', + rate_multiplier: 1, + daily_limit_usd: null, + weekly_limit_usd: null, + monthly_limit_usd: null, + features: [], + group_platform: 'openai', + sort_order: 1, + for_sale: true, + group_name: 'OpenAI', + }, + ], + }, + } +} + function jsapiOrderFixture(resumeToken: string) { return { order_id: 123, @@ -131,6 +160,24 @@ function jsapiOrderFixture(resumeToken: string) { } } +function oauthOrderFixture() { + return { + order_id: 456, + amount: 128, + pay_amount: 128, + fee_rate: 0, + expires_at: '2099-01-01T00:10:00.000Z', + payment_type: 'wxpay', + result_type: 'oauth_required' as const, + oauth: { + authorize_url: '/api/v1/auth/oauth/wechat/payment/start?payment_type=wxpay&redirect=%2Fpurchase%3Ffrom%3Dwechat', + appid: 'wx123', + scope: 'snsapi_base', + redirect_url: '/auth/wechat/payment/callback', + }, + } +} + describe('PaymentView WeChat JSAPI flow', () => { beforeEach(() => { routeState.path = '/purchase' @@ -239,4 +286,78 @@ describe('PaymentView WeChat JSAPI flow', () => { })) expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() }) + + it('keeps subscription resume context for token-only WeChat callbacks', async () => { + routeState.query = { + wechat_resume: '1', + wechat_resume_token: 'resume-subscription-7', + payment_type: 'wxpay_direct', + order_type: 'subscription', + plan_id: '7', + } + getCheckoutInfo.mockResolvedValue(checkoutInfoWithPlansFixture()) + createOrder.mockResolvedValue(oauthOrderFixture()) + + const originalLocation = window.location + const locationState = { + href: 'http://localhost/purchase', + origin: 'http://localhost', + } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState, + }) + + shallowMount(PaymentView, { + global: { + stubs: { + Teleport: true, + Transition: false, + }, + }, + }) + await flushPromises() + await flushPromises() + + expect(routerReplace).toHaveBeenCalledWith({ path: '/purchase', query: {} }) + expect(createOrder).toHaveBeenCalledWith(expect.objectContaining({ + payment_type: 'wxpay', + order_type: 'subscription', + plan_id: 7, + wechat_resume_token: 'resume-subscription-7', + })) + expect(locationState.href).toContain('/api/v1/auth/oauth/wechat/payment/start?') + expect(new URL(locationState.href, 'http://localhost').searchParams.get('redirect')).toBe( + '/purchase?from=wechat&payment_type=wxpay&order_type=subscription&plan_id=7', + ) + + Object.defineProperty(window, 'location', { + configurable: true, + value: originalLocation, + }) + }) + + it('shows explicit H5 authorization guidance instead of failing silently', async () => { + routeState.query = { + wechat_resume: '1', + wechat_resume_token: 'resume-token-h5', + payment_type: 'wxpay_direct', + } + createOrder.mockRejectedValueOnce({ reason: 'WECHAT_H5_NOT_AUTHORIZED' }) + + shallowMount(PaymentView, { + global: { + stubs: { + Teleport: true, + Transition: false, + }, + }, + }) + await flushPromises() + await flushPromises() + + expect(showError).toHaveBeenCalledWith( + 'payment.errors.wechatH5NotAuthorized payment.errors.wechatOpenInWeChatHint', + ) + }) }) diff --git a/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts b/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts index c850ec1b..85b2b0fd 100644 --- a/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts +++ b/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts @@ -14,8 +14,9 @@ describe('parseWechatResumeRoute', () => { }, [], 88)).toEqual({ wechatResumeToken: 'resume-token-123', paymentType: 'wxpay', - orderType: 'balance', + orderType: 'subscription', orderAmount: 0, + planId: 7, }) }) diff --git a/frontend/src/views/user/paymentWechatResume.ts b/frontend/src/views/user/paymentWechatResume.ts index 64f254da..8121bc56 100644 --- a/frontend/src/views/user/paymentWechatResume.ts +++ b/frontend/src/views/user/paymentWechatResume.ts @@ -37,12 +37,20 @@ export function parseWechatResumeRoute( } const wechatResumeToken = readQueryString(query, 'wechat_resume_token') + const paymentType = normalizeVisibleMethod(readQueryString(query, 'payment_type')) || 'wxpay' + const planId = Number.parseInt(readQueryString(query, 'plan_id'), 10) + const hasPlanId = Number.isFinite(planId) && planId > 0 + const orderType = readQueryString(query, 'order_type') === 'subscription' || hasPlanId + ? 'subscription' + : 'balance' + if (wechatResumeToken) { return { wechatResumeToken, - paymentType: 'wxpay', - orderType: 'balance', + paymentType, + orderType, orderAmount: 0, + planId: hasPlanId ? planId : undefined, } } @@ -51,9 +59,6 @@ export function parseWechatResumeRoute( return null } - const paymentType = normalizeVisibleMethod(readQueryString(query, 'payment_type')) || 'wxpay' - const orderType = readQueryString(query, 'order_type') === 'subscription' ? 'subscription' : 'balance' - const planId = Number.parseInt(readQueryString(query, 'plan_id'), 10) const rawAmount = Number.parseFloat(readQueryString(query, 'amount')) const orderAmount = Number.isFinite(rawAmount) && rawAmount > 0 ? rawAmount @@ -66,7 +71,7 @@ export function parseWechatResumeRoute( paymentType, orderType, orderAmount, - planId: Number.isFinite(planId) && planId > 0 ? planId : undefined, + planId: hasPlanId ? planId : undefined, } } From 06136af8050a3c986fdcb15540fae63fc8cbe3c2 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 13:18:10 +0800 Subject: [PATCH 14/31] fix(upgrade): preserve legacy auth and payment compatibility --- .../internal/handler/admin/setting_handler.go | 20 +++-- ...tting_handler_auth_source_defaults_test.go | 88 +++++++++++++++++++ .../internal/repository/migrations_runner.go | 11 ++- .../migrations_runner_extra_test.go | 13 +++ .../internal/service/payment_config_limits.go | 10 ++- .../service/payment_config_limits_test.go | 57 ++++++++++++ .../service/payment_resume_service_test.go | 61 +++++++++++-- .../payment_visible_method_instances.go | 20 ++++- backend/internal/service/setting_service.go | 21 +++-- .../setting_service_wechat_config_test.go | 30 +++++++ ...hat_dual_mode_and_auth_source_defaults.sql | 2 + ...ayment_orders_out_trade_no_unique_notx.sql | 2 - ...y_auth_source_grant_on_signup_defaults.sql | 42 +-------- ...tity_payment_migrations_regression_test.go | 10 +-- 14 files changed, 311 insertions(+), 76 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index d340a8a6..c6b45ab8 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -304,8 +304,8 @@ type UpdateSettingsRequest struct { OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"` OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"` OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"` - OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"` - OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"` + OIDCConnectUsePKCE *bool `json:"oidc_connect_use_pkce"` + OIDCConnectValidateIDToken *bool `json:"oidc_connect_validate_id_token"` OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"` OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"` OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"` @@ -682,6 +682,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } // Generic OIDC 参数验证 + oidcUsePKCE := previousSettings.OIDCConnectUsePKCE + oidcValidateIDToken := previousSettings.OIDCConnectValidateIDToken if req.OIDCConnectEnabled { req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName) req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID) @@ -716,11 +718,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath)) req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath)) req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, previousSettings.OIDCConnectUserInfoUsernamePath)) - if !req.OIDCConnectUsePKCE { - req.OIDCConnectUsePKCE = previousSettings.OIDCConnectUsePKCE + if req.OIDCConnectUsePKCE != nil { + oidcUsePKCE = *req.OIDCConnectUsePKCE } - if !req.OIDCConnectValidateIDToken { - req.OIDCConnectValidateIDToken = previousSettings.OIDCConnectValidateIDToken + if req.OIDCConnectValidateIDToken != nil { + oidcValidateIDToken = *req.OIDCConnectValidateIDToken } if req.OIDCConnectClockSkewSeconds == 0 { req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds @@ -795,7 +797,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") return } - if req.OIDCConnectValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" { + if oidcValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" { response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") return } @@ -1076,8 +1078,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OIDCConnectRedirectURL: req.OIDCConnectRedirectURL, OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL, OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod, - OIDCConnectUsePKCE: req.OIDCConnectUsePKCE, - OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken, + OIDCConnectUsePKCE: oidcUsePKCE, + OIDCConnectValidateIDToken: oidcValidateIDToken, OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs, OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds, OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified, diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index cef531e0..8045d0c9 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -247,6 +247,94 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS require.Equal(t, true, data["openai_advanced_scheduler_enabled"]) } +func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodSource(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + service.SettingPaymentVisibleMethodAlipayEnabled: "true", + service.SettingPaymentVisibleMethodAlipaySource: "", + service.SettingPaymentVisibleMethodWxpayEnabled: "false", + service.SettingPaymentVisibleMethodWxpaySource: "", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": false, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "", repo.values[service.SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled]) +} + +func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFlags(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyOIDCConnectEnabled: "true", + service.SettingKeyOIDCConnectProviderName: "OIDC", + service.SettingKeyOIDCConnectClientID: "oidc-client", + service.SettingKeyOIDCConnectClientSecret: "oidc-secret", + service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com", + service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth", + service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token", + service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo", + service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks", + service.SettingKeyOIDCConnectScopes: "openid email profile", + service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + service.SettingKeyOIDCConnectUsePKCE: "true", + service.SettingKeyOIDCConnectValidateIDToken: "true", + service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256", + service.SettingKeyOIDCConnectClockSkewSeconds: "120", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "oidc_connect_enabled": true, + "oidc_connect_use_pkce": false, + "oidc_connect_validate_id_token": false, + "oidc_connect_allowed_signing_algs": "", + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE]) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, false, data["oidc_connect_use_pkce"]) + require.Equal(t, false, data["oidc_connect_validate_id_token"]) +} + func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { gin.SetMode(gin.TestMode) repo := &settingHandlerRepoStub{ diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index edc85226..662a3972 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -62,10 +62,13 @@ type migrationChecksumCompatibilityRule struct { // 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行, // 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。 var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{ - "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"), - "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"), - "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), - "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"), + "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"), + "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"), + "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), + "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"), + "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"), + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"), + "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"), } // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go index 9f8a94c6..af1adc50 100644 --- a/backend/internal/repository/migrations_runner_extra_test.go +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -94,6 +94,19 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum)) } +func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) { + for _, name := range []string{ + "118_wechat_dual_mode_and_auth_source_defaults.sql", + "120_enforce_payment_orders_out_trade_no_unique_notx.sql", + "123_fix_legacy_auth_source_grant_on_signup_defaults.sql", + } { + rule, ok := migrationChecksumCompatibilityRules[name] + require.Truef(t, ok, "missing compatibility rule for %s", name) + require.NotEmpty(t, rule.fileChecksum) + require.NotEmpty(t, rule.acceptedDBChecksum) + } +} + func TestEnsureAtlasBaselineAligned(t *testing.T) { t.Run("skip_when_no_legacy_table", func(t *testing.T) { db, mock, err := sqlmock.New() diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go index e44bf2e7..973c601a 100644 --- a/backend/internal/service/payment_config_limits.go +++ b/backend/internal/service/payment_config_limits.go @@ -45,10 +45,18 @@ func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context. for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { matching := filterEnabledVisibleMethodInstances(instances, method) providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching) - if err != nil || providerKey == "" { + if err != nil { delete(filtered, method) continue } + if providerKey == "" { + if len(matching) == 0 { + delete(filtered, method) + continue + } + filtered[method] = matching + continue + } selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey) if len(selectedInstances) == 0 { delete(filtered, method) diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go index 12cd6866..4df506d6 100644 --- a/backend/internal/service/payment_config_limits_test.go +++ b/backend/internal/service/payment_config_limits_test.go @@ -6,6 +6,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" ) func TestUnionFloat(t *testing.T) { @@ -402,3 +403,59 @@ func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) }) } } + +func TestGetAvailableMethodLimitsPreservesLegacyCrossProviderBehaviorWhenVisibleMethodSourceMissing(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Mixed"). + SetConfig("{}"). + SetSupportedTypes("alipay,wxpay"). + SetLimits(`{"alipay":{"singleMin":20,"singleMax":200},"wxpay":{"singleMin":40,"singleMax":400}}`). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{}}, + } + + resp, err := svc.GetAvailableMethodLimits(ctx) + require.NoError(t, err) + + alipayLimits, ok := resp.Methods[payment.TypeAlipay] + require.True(t, ok, "expected alipay limits to remain visible") + require.Equal(t, 10.0, alipayLimits.SingleMin) + require.Equal(t, 200.0, alipayLimits.SingleMax) + + wxpayLimits, ok := resp.Methods[payment.TypeWxpay] + require.True(t, ok, "expected wxpay limits to remain visible") + require.Equal(t, 30.0, wxpayLimits.SingleMin) + require.Equal(t, 400.0, wxpayLimits.SingleMax) + + require.Equal(t, 10.0, resp.GlobalMin) + require.Equal(t, 400.0, resp.GlobalMax) +} diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index 9e756971..59a2221e 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -586,7 +586,60 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl } } -func TestVisibleMethodLoadBalancerRejectsMissingOrInvalidSourceWhenMultipleProvidersEnabled(t *testing.T) { +func TestVisibleMethodLoadBalancerPreservesLegacyCrossProviderRoutingWhenSourceMissing(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create official provider: %v", err) + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetSortOrder(2). + Save(ctx) + if err != nil { + t.Fatalf("create easypay provider: %v", err) + } + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + visibleMethodSourceSettingKey(payment.TypeAlipay): "", + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 9.9) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != "" { + t.Fatalf("lastProviderKey = %q, want legacy cross-provider empty key", inner.lastProviderKey) + } + if inner.lastPaymentType != payment.TypeAlipay { + t.Fatalf("lastPaymentType = %q, want %q", inner.lastPaymentType, payment.TypeAlipay) + } +} + +func TestVisibleMethodLoadBalancerRejectsInvalidSourceWhenMultipleProvidersEnabled(t *testing.T) { t.Parallel() tests := []struct { @@ -595,12 +648,6 @@ func TestVisibleMethodLoadBalancerRejectsMissingOrInvalidSourceWhenMultipleProvi sourceValue string wantMessage string }{ - { - name: "missing alipay source", - method: payment.TypeAlipay, - sourceValue: "", - wantMessage: "alipay source is required when the visible method is enabled", - }, { name: "invalid wxpay source", method: payment.TypeWxpay, diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go index 86ea5ead..5dcdab16 100644 --- a/backend/internal/service/payment_visible_method_instances.go +++ b/backend/internal/service/payment_visible_method_instances.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "strings" @@ -166,15 +167,21 @@ func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context if s != nil && s.settingRepo != nil && sourceKey != "" { value, err := s.settingRepo.GetValue(ctx, sourceKey) if err != nil { - return "", fmt.Errorf("get %s: %w", sourceKey, err) + if !errors.Is(err, ErrSettingNotFound) { + return "", fmt.Errorf("get %s: %w", sourceKey, err) + } + } else { + rawSource = value } - rawSource = value } normalizedSource, err := normalizeVisibleMethodSettingSource(method, rawSource, true) if err != nil { return "", err } + if normalizedSource == "" { + return "", nil + } providerKey, ok := VisibleMethodProviderKeyForSource(method, normalizedSource) if !ok { return "", infraerrors.BadRequest( @@ -200,6 +207,9 @@ func (s *PaymentConfigService) resolveVisibleMethodProviderKey( if err != nil { return "", err } + if providerKey == "" { + return "", nil + } selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey) if selected == nil { return "", infraerrors.BadRequest( @@ -237,5 +247,11 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( if err != nil { return nil, err } + if providerKey == "" { + if len(matching) == 0 { + return nil, nil + } + return &dbent.PaymentProviderInstance{ProviderKey: ""}, nil + } return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 72569882..f08274c7 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -282,7 +282,19 @@ func mergeWeChatConnectCapabilitySettings(settings map[string]string, base confi mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" if openConfigured || mpConfigured || mobileConfigured { - return parseWeChatConnectCapabilitySettings(settings, enabled, mode) + openEnabled := strings.TrimSpace(rawOpen) == "true" + mpEnabled := strings.TrimSpace(rawMP) == "true" + mobileEnabled := strings.TrimSpace(rawMobile) == "true" + _, enabledConfigured := settings[SettingKeyWeChatConnectEnabled] + if !enabledConfigured && + enabled && + !openEnabled && + !mpEnabled && + !mobileEnabled && + (base.OpenEnabled || base.MPEnabled || base.MobileEnabled) { + return base.OpenEnabled, base.MPEnabled, base.MobileEnabled + } + return openEnabled, mpEnabled, mobileEnabled } if !enabled { return false, false, false @@ -1921,14 +1933,9 @@ func isFalseSettingValue(value string) bool { } func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) { + _ = enabled source = strings.TrimSpace(source) if source == "" { - if enabled { - return "", infraerrors.BadRequest( - "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", - fmt.Sprintf("%s source is required when the visible method is enabled", method), - ) - } return "", nil } diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go index 08f67b7c..a2de614b 100644 --- a/backend/internal/service/setting_service_wechat_config_test.go +++ b/backend/internal/service/setting_service_wechat_config_test.go @@ -109,6 +109,36 @@ func TestSettingService_GetWeChatConnectOAuthConfig_FallsBackToConfigWhenDatabas require.Empty(t, got.RedirectURL) } +func TestSettingService_GetWeChatConnectOAuthConfig_IgnoresSyntheticDisabledCapabilitiesFromMigration118(t *testing.T) { + repo := &settingWeChatRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectOpenEnabled: "false", + SettingKeyWeChatConnectMPEnabled: "false", + }, + } + svc := NewSettingService(repo, &config.Config{ + WeChat: config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + MPEnabled: true, + Mode: "open", + OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + MPAppID: "wx-mp-config", + MPAppSecret: "wx-mp-secret", + FrontendRedirectURL: "/auth/wechat/config-callback", + }, + }) + + got, err := svc.GetWeChatConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.Enabled) + require.True(t, got.OpenEnabled) + require.True(t, got.MPEnabled) + require.Equal(t, "wx-open-config", got.AppIDForMode("open")) + require.Equal(t, "wx-mp-config", got.AppIDForMode("mp")) +} + func TestSettingService_ParseSettings_FallsBackToConfigForWeChatAdminView(t *testing.T) { svc := NewSettingService(&settingWeChatRepoStub{values: map[string]string{}}, &config.Config{ WeChat: config.WeChatConnectConfig{ diff --git a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql index 9b037984..18782617 100644 --- a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql +++ b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql @@ -3,6 +3,7 @@ VALUES ( 'wechat_connect_open_enabled', CASE + WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN '' WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false' WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'false' ELSE 'true' @@ -11,6 +12,7 @@ VALUES ( 'wechat_connect_mp_enabled', CASE + WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN '' WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false' WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'true' ELSE 'false' diff --git a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql index fe47698d..00836698 100644 --- a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql +++ b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql @@ -1,8 +1,6 @@ -- Build the payment order uniqueness guarantee online. -- Create the new partial unique index concurrently first so writes keep flowing, -- then remove the legacy index name once the replacement is ready. -DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique; - CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique ON payment_orders (out_trade_no) WHERE out_trade_no <> ''; diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql index f6053ef0..094b223c 100644 --- a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql +++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql @@ -1,39 +1,3 @@ -WITH migration_110 AS ( - SELECT applied_at - FROM schema_migrations - WHERE filename = '110_pending_auth_and_provider_default_grants.sql' -), -legacy_provider_defaults AS ( - SELECT provider_type - FROM ( - VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat') - ) AS providers(provider_type) - CROSS JOIN migration_110 - JOIN settings balance - ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance' - JOIN settings concurrency - ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency' - JOIN settings subscriptions - ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions' - JOIN settings grant_on_signup - ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' - JOIN settings grant_on_first_bind - ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind' - WHERE balance.value = '0' - AND concurrency.value = '5' - AND subscriptions.value = '[]' - AND grant_on_signup.value = 'true' - AND grant_on_first_bind.value = 'false' - AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' - AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' - AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' - AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' - AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' -) -UPDATE settings -SET - value = 'false', - updated_at = NOW() -FROM legacy_provider_defaults -WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup' - AND settings.value = 'true'; +-- Intentionally left as a no-op. +-- Legacy installs may have intentionally kept the original signup grant defaults, +-- and we cannot distinguish those cases safely from untouched migration 110 rows. diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index dbf8fc47..dcb0bb9c 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -24,6 +24,7 @@ func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T) require.NotContains(t, sql, "UPDATE settings") require.NotContains(t, sql, "SET value = 'false'") require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING")) + require.Contains(t, sql, "THEN ''") } func TestAuthIdentityReportTypeWideningRunsBeforeLongReportWritersAndStillReconcilesAt121(t *testing.T) { @@ -63,6 +64,7 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) { followupSQL := string(followupContent) require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique") + require.NotContains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique") require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") require.Contains(t, followupSQL, "WHERE out_trade_no <> ''") @@ -92,9 +94,7 @@ func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) require.NoError(t, err) sql := string(content) - require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql") - require.Contains(t, sql, "schema_migrations") - require.Contains(t, sql, "updated_at") - require.Contains(t, sql, "'_grant_on_signup'") - require.Contains(t, sql, "value = 'false'") + require.Contains(t, sql, "Intentionally left as a no-op") + require.NotContains(t, sql, "UPDATE settings") + require.NotContains(t, sql, "value = 'false'") } From 83cad63ce0c7737c5b09151d9710aaa8e3bd1e83 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 13:19:20 +0800 Subject: [PATCH 15/31] fix(auth): harden oauth callback adoption flows --- backend/internal/handler/auth_handler.go | 19 +++ .../internal/handler/auth_linuxdo_oauth.go | 2 +- .../handler/auth_linuxdo_oauth_test.go | 131 +++++++++++++++++ .../handler/auth_oauth_pending_flow.go | 18 +-- .../handler/auth_oauth_pending_flow_test.go | 50 +++++++ .../handler/auth_oauth_test_helpers_test.go | 18 +++ backend/internal/handler/auth_oidc_oauth.go | 2 +- .../internal/handler/auth_oidc_oauth_test.go | 114 +++++++++++++++ backend/internal/handler/auth_wechat_oauth.go | 18 ++- .../handler/auth_wechat_oauth_test.go | 133 ++++++++++++++++++ 10 files changed, 490 insertions(+), 15 deletions(-) diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index acd43e9f..ca3a5a77 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -78,9 +78,24 @@ type AuthResponse struct { User *dto.User `json:"user"` } +func ensureLoginUserActive(user *service.User) error { + if user == nil { + return infraerrors.Unauthorized("INVALID_USER", "user not found") + } + if !user.IsActive() { + return service.ErrUserNotActive + } + return nil +} + // respondWithTokenPair 生成 Token 对并返回认证响应 // 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容) func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { + if err := ensureLoginUserActive(user); err != nil { + response.ErrorFrom(c, err) + return + } + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") if err != nil { slog.Error("failed to generate token pair", "error", err, "user_id", user.ID) @@ -293,6 +308,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { response.ErrorFrom(c, err) return } + if err := ensureLoginUserActive(user); err != nil { + response.ErrorFrom(c, err) + return + } if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 157be066..a7e77c09 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -495,7 +495,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ AdoptDisplayName: req.AdoptDisplayName, AdoptAvatar: req.AdoptAvatar, }) diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index a9a5e3e6..d535c178 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -408,6 +408,74 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t require.Nil(t, completion["error"]) } +func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(linuxDoSyntheticEmail("654")). + SetUsername("disabled-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("654"). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { @@ -812,6 +880,69 @@ func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillReq require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-subject-no-adoption"). + SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-browser-no-adoption"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "LinuxDo Legacy", + "suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "linuxdo_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { t.Helper() handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index c5df4db1..7be01e74 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -464,15 +464,7 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic } return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - - userEntity, err := client.User.Get(ctx, record.UserID) - if err != nil { - if dbent.IsNotFound(err) { - return nil, nil - } - return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) - } - return userEntity, nil + return findActiveUserByID(ctx, client, record.UserID) } func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") } @@ -998,6 +990,9 @@ func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) } return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) } + if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) { + return nil, service.ErrUserNotActive + } return userEntity, nil } @@ -1801,6 +1796,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { response.ErrorFrom(c, err) return } + if err := ensureLoginUserActive(loginUser); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil { clearCookies() response.ErrorFrom(c, err) diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index b3b8dfe1..bc8fe7eb 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -851,6 +851,56 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl require.Nil(t, storedSession.ConsumedAt) } +func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("disabled-linked@example.com"). + SetUsername("disabled-linked-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("disabled-linked-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("disabled-linked-subject"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("disabled-linked-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Disabled Linked User", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) { payload := normalizePendingOAuthCompletionResponse(map[string]any{ "access_token": "legacy-access-token", diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go index 8eb87dbb..47bad942 100644 --- a/backend/internal/handler/auth_oauth_test_helpers_test.go +++ b/backend/internal/handler/auth_oauth_test_helpers_test.go @@ -2,6 +2,7 @@ package handler import ( "net/http" + "net/url" "testing" "github.com/stretchr/testify/require" @@ -37,3 +38,20 @@ func decodeCookieValueForTest(t *testing.T, value string) string { require.NoError(t, err) return decoded } + +func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { + t.Helper() + require.NotEmpty(t, location) + + parsed, err := url.Parse(location) + require.NoError(t, err) + + rawValues := parsed.RawQuery + if rawValues == "" { + rawValues = parsed.Fragment + } + values, err := url.ParseQuery(rawValues) + require.NoError(t, err) + require.Equal(t, errorCode, values.Get("error")) + require.Equal(t, errorMessage, values.Get("error_message")) +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 6345938b..3c67e421 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -648,7 +648,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ AdoptDisplayName: req.AdoptDisplayName, AdoptAvatar: req.AdoptAvatar, }) diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index 63008344..c2855dc9 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -340,6 +340,56 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t require.Nil(t, completion["error"]) } +func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-disabled-subject", + PreferredUsername: "oidc_disabled", + DisplayName: "OIDC Disabled", + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))). + SetUsername("disabled-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("oidc"). + SetProviderKey(cfg.IssuerURL). + SetProviderSubject("oidc-disabled-subject"). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ Subject: "oidc-subject-compat", @@ -748,6 +798,70 @@ func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequir require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-subject-no-adoption"). + SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-browser-no-adoption"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + "suggested_display_name": "OIDC Legacy", + "suggested_avatar_url": "https://cdn.example/oidc-legacy.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "oidc_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + type oidcProviderFixture struct { Subject string PreferredUsername string diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 3ed20a7d..dc93fcae 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -551,7 +551,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ AdoptDisplayName: req.AdoptDisplayName, AdoptAvatar: req.AdoptAvatar, }) @@ -827,7 +827,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } if user, err := singleWeChatIdentityUser(records); err != nil || user != nil { - return user, err + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) } } @@ -851,7 +854,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) } if user, err := singleWeChatChannelUser(records); err != nil || user != nil { - return user, err + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) } } @@ -870,7 +876,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( if err != nil { return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - return singleWeChatIdentityUser(records) + user, err := singleWeChatIdentityUser(records) + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) } func wechatCompatibleProviderKeys(providerKey string) []string { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 349e7dd2..b8bd21ce 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/repository" @@ -292,6 +293,71 @@ func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWit require.False(t, hasRefreshToken) } +func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(wechatSyntheticEmail("union-disabled")). + SetUsername("disabled-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-disabled"). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL t.Cleanup(func() { @@ -816,6 +882,73 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPend require.Zero(t, decisionCount) } +func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("wechat-subject-no-adoption"). + SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid"). + SetBrowserSessionKey("wechat-browser-no-adoption"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + "suggested_display_name": "WeChat Legacy", + "suggested_avatar_url": "https://cdn.example/wechat-legacy.png", + "mode": "open", + "channel": "open", + "channel_app_id": "wx-open-app", + "channel_subject": "openid-legacy", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + completeReq.Header.Set("Content-Type", "application/json") + completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")}) + completeCtx.Request = completeReq + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "wechat_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL From 81c827ee5128f77d148461eeea4079a04b53c087 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 13:19:28 +0800 Subject: [PATCH 16/31] fix(profile): stabilize identity binding management --- .../handler/auth_current_user_test.go | 21 +-- backend/internal/handler/user_handler.go | 17 +- backend/internal/handler/user_handler_test.go | 159 ++++++++++++++++-- .../repository/user_profile_identity_repo.go | 77 +++++++-- ...ser_profile_identity_repo_contract_test.go | 73 ++++++++ backend/internal/repository/user_repo.go | 21 +++ .../user_repo_email_lookup_unit_test.go | 77 +++++++++ ...dmin_service_auth_identity_binding_test.go | 87 ++++++++++ backend/internal/service/user_service_test.go | 4 +- .../ProfileIdentityBindingsSection.vue | 9 +- .../user/profile/ProfileInfoCard.vue | 14 +- .../ProfileIdentityBindingsSection.spec.ts | 23 +++ .../profile/__tests__/ProfileInfoCard.spec.ts | 41 +++++ 13 files changed, 584 insertions(+), 39 deletions(-) diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go index 31d92a36..cb3e4ba5 100644 --- a/backend/internal/handler/auth_current_user_test.go +++ b/backend/internal/handler/auth_current_user_test.go @@ -29,18 +29,19 @@ func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T AvatarURL: "https://cdn.example.com/linuxdo.png", AvatarSource: "remote_url", }, - identities: []service.UserAuthIdentityRecord{ - { - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: "linuxdo-subject-31", - VerifiedAt: &verifiedAt, - Metadata: map[string]any{ - "username": "linuxdo-handle", + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-31", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", + }, }, }, - }, - } + } handler := &AuthHandler{ userService: service.NewUserService(repo, nil, nil, nil), diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 3e5ca080..80dcc5ce 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -258,6 +258,12 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) { response.ErrorFrom(c, err) return } + if h.authService != nil { + if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil { + response.ErrorFrom(c, err) + return + } + } profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) if err != nil { @@ -504,8 +510,12 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity thirdParty := thirdPartyIdentityProviders(identities) var avatarSource *userProfileSourceContext - if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 { - avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider) + avatarValue := strings.TrimSpace(user.AvatarURL) + for _, summary := range thirdParty { + if avatarValue != "" && avatarValue == strings.TrimSpace(summary.DisplayName) { + avatarSource = buildUserProfileSourceContext(summary.Provider) + break + } } usernameValue := strings.TrimSpace(user.Username) @@ -516,9 +526,6 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity break } } - if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 { - usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider) - } profileSources := map[string]*userProfileSourceContext{} if avatarSource != nil { diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 51d5a814..87e168c0 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -270,18 +270,19 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { AvatarURL: "https://cdn.example.com/linuxdo.png", AvatarSource: "remote_url", }, - identities: []service.UserAuthIdentityRecord{ - { - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: "linuxdo-subject-21", - VerifiedAt: &verifiedAt, - Metadata: map[string]any{ - "username": "linuxdo-handle", + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", + }, }, }, - }, - } + } handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) recorder := httptest.NewRecorder() @@ -331,10 +332,102 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { require.Equal(t, "linuxdo", usernameSource["source"]) } +func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 22, + Email: "edited-profile@example.com", + Username: "custom-name", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/custom.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-22", + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 22}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.NotContains(t, resp.Data, "avatar_source") + require.NotContains(t, resp.Data, "username_source") + require.NotContains(t, resp.Data, "profile_sources") +} + type userHandlerEmailCacheStub struct { data *service.VerificationCodeData } +type userHandlerRefreshTokenCacheStub struct { + revokedUserIDs []int64 +} + +func (s *userHandlerRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *userHandlerRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error { + s.revokedUserIDs = append(s.revokedUserIDs, userID) + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *userHandlerRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *userHandlerRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) { return s.data, nil } @@ -495,6 +588,52 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) { require.Equal(t, false, linuxdoBinding["bound"]) } +func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 23, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "identity@example.com", + }, + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-23", + }, + }, + } + refreshTokenCache := &userHandlerRefreshTokenCacheStub{} + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 23}) + c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}} + + handler.UnbindIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, []int64{23}, refreshTokenCache.revokedUserIDs) +} + func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go index 2d812394..87094ad7 100644 --- a/backend/internal/repository/user_profile_identity_repo.go +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -301,17 +301,18 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA client := clientFromContext(txCtx, r.client) canonical := input.Canonical - identity, err := client.AuthIdentity.Query(). + identityRecords, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)), - authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)), + authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...), authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)), ). - Only(txCtx) - if err != nil && !dbent.IsNotFound(err) { + All(txCtx) + if err != nil { return err } - if identity != nil && identity.UserID != input.UserID { + identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID) + if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) { return ErrAuthIdentityOwnershipConflict } if identity == nil { @@ -346,20 +347,21 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA var channel *dbent.AuthIdentityChannel if input.Channel != nil { - channel, err = client.AuthIdentityChannel.Query(). + channelRecords, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)), - authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)), + authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...), authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)), authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)), authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)), ). WithIdentity(). - Only(txCtx) - if err != nil && !dbent.IsNotFound(err) { + All(txCtx) + if err != nil { return err } - if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID { + channel = selectOwnedCompatibleChannel(channelRecords, input.UserID) + if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) { return ErrAuthIdentityChannelOwnershipConflict } if channel == nil { @@ -397,6 +399,61 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA return result, nil } +func compatibleIdentityProviderKeys(providerType, providerKey string) []string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" { + return []string{providerKey} + } + if providerType != "wechat" { + return []string{providerKey} + } + keys := []string{providerKey} + if !strings.EqualFold(providerKey, "wechat-main") { + keys = append(keys, "wechat-main") + } + if !strings.EqualFold(providerKey, "wechat") { + keys = append(keys, "wechat") + } + return keys +} + +func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity { + for _, record := range records { + if record.UserID == userID { + return record + } + } + return nil +} + +func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool { + for _, record := range records { + if record.UserID != userID { + return true + } + } + return false +} + +func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel { + for _, record := range records { + if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID { + return record + } + } + return nil +} + +func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool { + for _, record := range records { + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + return true + } + } + return false +} + func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) { exec := txAwareSQLExecutor(ctx, r.sql, r.client) if exec == nil { diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go index 697e96a4..69a25fbe 100644 --- a/backend/internal/repository/user_profile_identity_repo_contract_test.go +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -186,6 +186,79 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict) } +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords() { + user := s.mustCreateUser("wechat-legacy-alias") + + legacyIdentity, err := s.client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetProviderSubject("union-legacy-123"). + SetMetadata(map[string]any{"source": "legacy-alias"}). + Save(s.ctx) + s.Require().NoError(err) + + legacyChannel, err := s.client.AuthIdentityChannel.Create(). + SetIdentityID(legacyIdentity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("oa"). + SetChannelAppID("wx-app-legacy"). + SetChannelSubject("openid-legacy-123"). + SetMetadata(map[string]any{"scene": "legacy-alias"}). + Save(s.ctx) + s.Require().NoError(err) + + bound, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-legacy-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + Channel: "oa", + ChannelAppID: "wx-app-legacy", + ChannelSubject: "openid-legacy-123", + }, + Metadata: map[string]any{"source": "canonical-bind"}, + ChannelMetadata: map[string]any{"scene": "canonical-bind"}, + }) + s.Require().NoError(err) + s.Require().NotNil(bound) + s.Require().NotNil(bound.Identity) + s.Require().NotNil(bound.Channel) + s.Require().Equal(legacyIdentity.ID, bound.Identity.ID) + s.Require().Equal(legacyChannel.ID, bound.Channel.ID) + s.Require().Equal("wechat-main", bound.Identity.ProviderKey) + s.Require().Equal("wechat-main", bound.Channel.ProviderKey) + s.Require().Equal("canonical-bind", bound.Identity.Metadata["source"]) + s.Require().Equal("canonical-bind", bound.Channel.Metadata["scene"]) + + identityCount, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderSubjectEQ("union-legacy-123"), + ). + Count(s.ctx) + s.Require().NoError(err) + s.Require().Equal(1, identityCount) + + channelCount, err := s.client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ChannelEQ("oa"), + authidentitychannel.ChannelAppIDEQ("wx-app-legacy"), + authidentitychannel.ChannelSubjectEQ("openid-legacy-123"), + ). + Count(s.ctx) + s.Require().NoError(err) + s.Require().Equal(1, channelCount) +} + func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() { user := s.mustCreateUser("provider-mismatch-create") diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index c7d301c7..68e51eeb 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -43,6 +43,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error if userIn == nil { return nil } + if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil { + return err + } // 统一使用 ent 的事务:保证用户与允许分组的更新原子化, // 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。 @@ -146,6 +149,9 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error if userIn == nil { return nil } + if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil { + return err + } // 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。 tx, err := r.client.Tx(ctx) @@ -704,6 +710,21 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) } +func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error { + matches, err := r.client.User.Query(). + Where(userEmailLookupPredicate(email)). + All(ctx) + if err != nil { + return err + } + for _, match := range matches { + if match.ID != userID { + return service.ErrEmailExists + } + } + return nil +} + func userEmailLookupPredicate(email string) predicate.User { normalized := strings.ToLower(strings.TrimSpace(email)) if normalized == "" { diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go index d42ce9ac..b2b02ef5 100644 --- a/backend/internal/repository/user_repo_email_lookup_unit_test.go +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -67,3 +67,80 @@ func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) require.NoError(t, err) require.True(t, exists) } + +func TestUserRepositoryCreateRejectsNormalizedEmailDuplicate(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Existing@Example.com ", + Username: "existing-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + err = repo.Create(ctx, &service.User{ + Email: "existing@example.com", + Username: "duplicate-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.ErrorIs(t, err, service.ErrEmailExists) +} + +func TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + first := &service.User{ + Email: " Existing@Example.com ", + Username: "existing-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, first)) + + second := &service.User{ + Email: "second@example.com", + Username: "second-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, second)) + + second.Email = " existing@example.com " + err := repo.Update(ctx, second) + require.ErrorIs(t, err, service.ErrEmailExists) +} + +func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + _, err := client.User.Create(). + SetEmail("Conflict@Example.com"). + SetUsername("conflict-user-1"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.User.Create(). + SetEmail(" conflict@example.com "). + SetUsername("conflict-user-2"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = repo.GetByEmail(ctx, "conflict@example.com") + require.Error(t, err) + require.ErrorContains(t, err, "normalized email lookup matched multiple users") +} diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go index f8ce3935..719199f2 100644 --- a/backend/internal/service/admin_service_auth_identity_binding_test.go +++ b/backend/internal/service/admin_service_auth_identity_binding_test.go @@ -188,6 +188,93 @@ func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) { require.Equal(t, "second", identities[0].Metadata["source"]) } +func TestAdminServiceBindUserAuthIdentityReusesLegacyWeChatAliasRecords(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("wechat-alias@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetProviderSubject("union-legacy-123"). + SetMetadata(map[string]any{"source": "legacy"}). + Save(ctx) + require.NoError(t, err) + + legacyChannel, err := client.AuthIdentityChannel.Create(). + SetIdentityID(legacyIdentity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("open"). + SetChannelAppID("wx-open"). + SetChannelSubject("openid-legacy-123"). + SetMetadata(map[string]any{"scene": "legacy"}). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-legacy-123", + Metadata: map[string]any{"source": "admin-repair"}, + Channel: &AdminBindAuthIdentityChannelInput{ + Channel: "open", + ChannelAppID: "wx-open", + ChannelSubject: "openid-legacy-123", + Metadata: map[string]any{"scene": "admin-repair"}, + }, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "wechat-main", result.ProviderKey) + require.NotNil(t, result.Channel) + require.Equal(t, "open", result.Channel.Channel) + + identity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", identity.ProviderKey) + require.Equal(t, "admin-repair", identity.Metadata["source"]) + + channel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", channel.ProviderKey) + require.Equal(t, legacyIdentity.ID, channel.IdentityID) + require.Equal(t, "admin-repair", channel.Metadata["scene"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderSubjectEQ("union-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + + channelCount, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open"), + authidentitychannel.ChannelSubjectEQ("openid-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, channelCount) +} + func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) { client := newAdminServiceAuthIdentityBindingTestClient(t) ctx := context.Background() diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 109d459d..0ad95356 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -406,13 +406,15 @@ func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testin }, }, } - svc := NewUserService(repo, nil, nil, nil) + invalidator := &mockAuthCacheInvalidator{} + svc := NewUserService(repo, nil, invalidator, nil) user, err := svc.UnbindUserAuthProvider(context.Background(), 12, "linuxdo") require.NoError(t, err) require.Equal(t, []string{"linuxdo"}, repo.unboundProviders) require.Equal(t, int64(12), user.ID) + require.Equal(t, []int64{12}, invalidator.invalidatedUserIDs) summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 12, user) require.NoError(t, err) diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue index 48b1b879..8a3af858 100644 --- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue +++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue @@ -444,7 +444,14 @@ function providerIconClass(provider: UserAuthProvider): string { function providerSummary(provider: UserAuthProvider): string { if (provider === 'email') { - return currentUser.value?.email || '' + const email = currentUser.value?.email?.trim() || '' + if (!email) { + return '' + } + if (currentUser.value?.email_bound === false && email.endsWith('.invalid')) { + return '' + } + return email } return '' } diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue index 99559de5..4544c337 100644 --- a/frontend/src/components/user/profile/ProfileInfoCard.vue +++ b/frontend/src/components/user/profile/ProfileInfoCard.vue @@ -40,7 +40,7 @@

- {{ user?.email }} + {{ primaryEmailDisplay }}

props.user?.avatar_url?.trim() || '') const displayName = computed(() => props.user?.username?.trim() || props.user?.email?.trim() || t('profile.user')) +const primaryEmailDisplay = computed(() => { + const email = props.user?.email?.trim() || '' + if (!email) { + return '' + } + if (props.user?.email_bound === false && email.endsWith('.invalid')) { + return '' + } + return email +}) const avatarInitial = computed(() => displayName.value.charAt(0).toUpperCase() || 'U') const memberSinceLabel = computed(() => { const raw = props.user?.created_at?.trim() @@ -229,7 +239,7 @@ const memberSinceLabel = computed(() => { const providerLabels = computed>(() => ({ email: t('profile.authBindings.providers.email'), linuxdo: t('profile.authBindings.providers.linuxdo'), - oidc: t('profile.authBindings.providers.oidc', { providerName: 'OIDC' }), + oidc: t('profile.authBindings.providers.oidc', { providerName: props.oidcProviderName }), wechat: t('profile.authBindings.providers.wechat') })) diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts index 9d8c88d4..77d2219e 100644 --- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts +++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts @@ -335,6 +335,29 @@ describe('ProfileIdentityBindingsSection', () => { expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true) }) + it('does not show a synthetic oauth-only email as the bound email summary', () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + global: { + plugins: [pinia], + }, + props: { + user: createUser({ + email: 'legacy-user@linuxdo-connect.invalid', + email_bound: false, + auth_bindings: { + email: { bound: false }, + }, + }), + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: false, + }, + }) + + expect(wrapper.text()).not.toContain('legacy-user@linuxdo-connect.invalid') + expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound') + }) + it('keeps the email form available for replacing a bound primary email', async () => { userApiMocks.sendEmailBindingCode.mockResolvedValue(undefined) userApiMocks.bindEmailIdentity.mockResolvedValue( diff --git a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts index 229c27cb..c7e60d9b 100644 --- a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts +++ b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts @@ -111,6 +111,47 @@ describe('ProfileInfoCard', () => { expect(wrapper.text()).toContain('Username synced from LinuxDo') }) + it('uses the configured OIDC provider name in source hints', () => { + const wrapper = mount(ProfileInfoCard, { + props: { + user: createUser({ + profile_sources: { + username: { provider: 'oidc', source: 'oidc' } + } + }), + oidcProviderName: 'ExampleID' + }, + global: { + stubs: { + Icon: true + } + } + }) + + expect(wrapper.text()).toContain('Username synced from ExampleID') + }) + + it('does not display synthetic oauth-only emails as a real bound email', () => { + const wrapper = mount(ProfileInfoCard, { + props: { + user: createUser({ + email: 'legacy-user@oidc-connect.invalid', + email_bound: false, + auth_bindings: { + email: { bound: false } + } + }) + }, + global: { + stubs: { + Icon: true + } + } + }) + + expect(wrapper.text()).not.toContain('legacy-user@oidc-connect.invalid') + }) + it('renders the approved overview hero and two-column content shell', () => { const wrapper = mount(ProfileInfoCard, { props: { From 6696e61c7b158ce415ecbca0f3f00a7906df9957 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 13:19:41 +0800 Subject: [PATCH 17/31] fix(frontend): preserve callback recovery state --- frontend/src/router/__tests__/guards.spec.ts | 16 ++++++- frontend/src/router/index.ts | 2 +- .../src/views/auth/LinuxDoCallbackView.vue | 17 ++++--- frontend/src/views/auth/OidcCallbackView.vue | 17 ++++--- .../src/views/auth/WechatCallbackView.vue | 17 ++++--- .../views/auth/WechatPaymentCallbackView.vue | 28 +++++++++++- .../__tests__/LinuxDoCallbackView.spec.ts | 44 +++++++++++++++++++ .../auth/__tests__/OidcCallbackView.spec.ts | 44 +++++++++++++++++++ .../auth/__tests__/WechatCallbackView.spec.ts | 44 +++++++++++++++++++ .../WechatPaymentCallbackView.spec.ts | 23 ++++++++++ 10 files changed, 229 insertions(+), 23 deletions(-) diff --git a/frontend/src/router/__tests__/guards.spec.ts b/frontend/src/router/__tests__/guards.spec.ts index bdf07b18..076b943d 100644 --- a/frontend/src/router/__tests__/guards.spec.ts +++ b/frontend/src/router/__tests__/guards.spec.ts @@ -78,7 +78,7 @@ function simulateGuard( return authState.isAdmin ? '/admin/dashboard' : '/dashboard' } if (authState.backendModeEnabled && !authState.isAuthenticated) { - const allowed = ['/login', '/key-usage', '/setup'] + const allowed = ['/login', '/key-usage', '/setup', '/payment/result'] const callbackPaths = [ '/auth/callback', '/auth/linuxdo/callback', @@ -127,7 +127,7 @@ function simulateGuard( if (authState.isAuthenticated && authState.isAdmin) { return null } - const allowed = ['/login', '/key-usage', '/setup'] + const allowed = ['/login', '/key-usage', '/setup', '/payment/result'] const callbackPaths = [ '/auth/callback', '/auth/linuxdo/callback', @@ -462,6 +462,18 @@ describe('路由守卫逻辑', () => { expect(redirect).toBeNull() }) + it('unauthenticated: /payment/result is allowed', () => { + const authState: MockAuthState = { + isAuthenticated: false, + isAdmin: false, + isSimpleMode: false, + backendModeEnabled: true, + hasPendingAuthSession: false, + } + const redirect = simulateGuard('/payment/result', { requiresAuth: false }, authState) + expect(redirect).toBeNull() + }) + it('unauthenticated: /register is allowed when a pending auth session exists', () => { const authState: MockAuthState = { isAuthenticated: false, diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index b7fcf475..b97ccb5d 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -542,7 +542,7 @@ let authInitialized = false const navigationLoading = useNavigationLoadingState() // 延迟初始化预加载,传入 router 实例 let routePrefetch: ReturnType | null = null -const BACKEND_MODE_ALLOWED_PATHS = ['/login', '/key-usage', '/setup'] +const BACKEND_MODE_ALLOWED_PATHS = ['/login', '/key-usage', '/setup', '/payment/result'] const BACKEND_MODE_CALLBACK_PATHS = [ '/auth/callback', '/auth/linuxdo/callback', diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 4009454c..2cf4e694 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -603,6 +603,14 @@ async function finalizePendingAccountResponse(completion: LinuxDoPendingActionRe return } + if (completion.auth_result === 'pending_session') { + needsInvitation.value = false + needsAdoptionConfirmation.value = false + isProcessing.value = false + persistPendingAuthSession(redirect) + return + } + await finalizeCompletion(completion, redirect) } @@ -612,9 +620,9 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { - const tokenData = legacyPendingOAuthToken.value + const completion: LinuxDoPendingActionResponse = legacyPendingOAuthToken.value ? ( - await apiClient.post('/auth/oauth/linuxdo/complete-registration', { + await apiClient.post('/auth/oauth/linuxdo/complete-registration', { pending_oauth_token: legacyPendingOAuthToken.value, invitation_code: invitationCode.value.trim(), ...serializeAdoptionDecision(currentAdoptionDecision()) @@ -624,10 +632,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - persistOAuthTokenContext(tokenData) - await authStore.setToken(tokenData.access_token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirectTo.value) + await finalizePendingAccountResponse(completion) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { message?: string } } } invitationError.value = diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index d03c70fb..873022e1 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -632,6 +632,14 @@ async function finalizePendingAccountResponse(completion: PendingOidcCompletion) return } + if (completion.auth_result === 'pending_session') { + needsInvitation.value = false + needsAdoptionConfirmation.value = false + isProcessing.value = false + persistPendingAuthSession(redirect) + return + } + await finalizeCompletion(completion, redirect) } @@ -641,9 +649,9 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { - const tokenData = legacyPendingOAuthToken.value + const completion: PendingOidcCompletion = legacyPendingOAuthToken.value ? ( - await apiClient.post('/auth/oauth/oidc/complete-registration', { + await apiClient.post('/auth/oauth/oidc/complete-registration', { pending_oauth_token: legacyPendingOAuthToken.value, invitation_code: invitationCode.value.trim(), ...serializeAdoptionDecision(currentAdoptionDecision()) @@ -653,10 +661,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - persistOAuthTokenContext(tokenData) - await authStore.setToken(tokenData.access_token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirectTo.value) + await finalizePendingAccountResponse(completion) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { message?: string } } } invitationError.value = diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue index 9a71f62b..bae20df8 100644 --- a/frontend/src/views/auth/WechatCallbackView.vue +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -840,6 +840,14 @@ async function finalizePendingAccountResponse(completion: PendingWeChatCompletio return } + if (completion.auth_result === 'pending_session') { + needsInvitation.value = false + needsAdoptionConfirmation.value = false + isProcessing.value = false + persistPendingAuthSession(redirect) + return + } + await finalizeCompletion(completion, redirect) } @@ -849,9 +857,9 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { - const tokenData = legacyPendingOAuthToken.value + const completion: PendingWeChatCompletion = legacyPendingOAuthToken.value ? ( - await apiClient.post('/auth/oauth/wechat/complete-registration', { + await apiClient.post('/auth/oauth/wechat/complete-registration', { pending_oauth_token: legacyPendingOAuthToken.value, invitation_code: invitationCode.value.trim(), ...serializeAdoptionDecision(currentAdoptionDecision()) @@ -861,10 +869,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - persistOAuthTokenContext(tokenData) - await authStore.setToken(tokenData.access_token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirectTo.value) + await finalizePendingAccountResponse(completion) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { message?: string } } } invitationError.value = diff --git a/frontend/src/views/auth/WechatPaymentCallbackView.vue b/frontend/src/views/auth/WechatPaymentCallbackView.vue index 53599ec3..225c84e1 100644 --- a/frontend/src/views/auth/WechatPaymentCallbackView.vue +++ b/frontend/src/views/auth/WechatPaymentCallbackView.vue @@ -85,6 +85,12 @@ function normalizeRedirectPath(path: string | null | undefined): string { return value } +function appendQueryParam(query: Record, key: string, value: string) { + if (value) { + query[key] = value + } +} + function goBackToPayment() { void router.replace('/purchase') } @@ -102,12 +108,19 @@ onMounted(async () => { } const resumeToken = readParam('wechat_resume_token') + const openid = readParam('openid') + const state = readParam('state') + const scope = readParam('scope') + const paymentType = readParam('payment_type') + const amount = readParam('amount') + const orderType = readParam('order_type') + const planId = readParam('plan_id') const redirectURL = new URL( normalizeRedirectPath(readParam('redirect')), window.location.origin, ) - if (!resumeToken) { + if (!resumeToken && !openid) { errorMessage.value = t('auth.wechatPayment.callbackMissingResumeToken') return } @@ -115,7 +128,18 @@ onMounted(async () => { const query: Record = { ...Object.fromEntries(redirectURL.searchParams.entries()), wechat_resume: '1', - wechat_resume_token: resumeToken, + } + + if (resumeToken) { + query.wechat_resume_token = resumeToken + } else { + query.openid = openid + appendQueryParam(query, 'state', state) + appendQueryParam(query, 'scope', scope) + appendQueryParam(query, 'payment_type', paymentType) + appendQueryParam(query, 'amount', amount) + appendQueryParam(query, 'order_type', orderType) + appendQueryParam(query, 'plan_id', planId) } await router.replace({ diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts index 29aef613..3fee2c27 100644 --- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -409,6 +409,50 @@ describe('LinuxDoCallbackView', () => { }) }) + it('keeps the oauth flow active when complete-registration returns another pending step', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + completeLinuxDoOAuthRegistration.mockResolvedValue({ + auth_result: 'pending_session', + step: 'choose_account_action_required', + redirect: '/dashboard', + email: 'fresh@example.com', + resolved_email: 'fresh@example.com', + force_email_on_signup: true, + adoption_required: true + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('auth.oauthFlow.bindExistingAccount') + expect(wrapper.text()).toContain('auth.oauthFlow.createNewAccount') + }) + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { getPublicSettings.mockResolvedValue({ invitation_code_enabled: true, diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts index 0167604c..ec89512b 100644 --- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -385,6 +385,50 @@ describe('OidcCallbackView', () => { }) }) + it('keeps the oauth flow active when complete-registration returns another pending step', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + completeOIDCOAuthRegistration.mockResolvedValue({ + auth_result: 'pending_session', + step: 'choose_account_action_required', + redirect: '/dashboard', + email: 'fresh@example.com', + resolved_email: 'fresh@example.com', + force_email_on_signup: true, + adoption_required: true + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('auth.oauthFlow.bindExistingAccount') + expect(wrapper.text()).toContain('auth.oauthFlow.createNewAccount') + }) + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { getPublicSettings.mockResolvedValue({ oidc_oauth_provider_name: 'ExampleID', diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts index cc72107d..da41c987 100644 --- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -517,6 +517,50 @@ describe('WechatCallbackView', () => { expect(replaceMock).toHaveBeenCalledWith('/subscriptions') }) + it('keeps the oauth flow active when complete-registration returns another pending step', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + completeWeChatOAuthRegistrationMock.mockResolvedValue({ + auth_result: 'pending_session', + step: 'choose_account_action_required', + redirect: '/dashboard', + email: 'fresh@example.com', + resolved_email: 'fresh@example.com', + force_email_on_signup: true, + adoption_required: true, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: true, + }) + expect(setTokenMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + expect(wrapper.get('[data-testid="wechat-choice-bind-existing"]').exists()).toBe(true) + expect(wrapper.get('[data-testid="wechat-choice-create-account"]').exists()).toBe(true) + }) + it('offers existing-account email collection during invitation flow', async () => { exchangePendingOAuthCompletionMock.mockResolvedValue({ error: 'invitation_required', diff --git a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts index 93cd0e94..822a083b 100644 --- a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts @@ -79,6 +79,29 @@ describe('WechatPaymentCallbackView', () => { }) }) + it('redirects legacy openid callback payloads back to purchase while preserving resume context', async () => { + locationState.current.hash = + '#openid=openid-123&state=oauth-state&scope=snsapi_base&payment_type=wxpay_direct&amount=128&order_type=subscription&plan_id=7&redirect=%2Fpayment%3Ffrom%3Dwechat' + + mount(WechatPaymentCallbackView) + await flushPromises() + + expect(replaceMock).toHaveBeenCalledWith({ + path: '/purchase', + query: { + from: 'wechat', + wechat_resume: '1', + openid: 'openid-123', + state: 'oauth-state', + scope: 'snsapi_base', + payment_type: 'wxpay_direct', + amount: '128', + order_type: 'subscription', + plan_id: '7', + }, + }) + }) + it('shows an error when the callback payload is missing the resume token', async () => { locationState.current.hash = '#payment_type=wxpay' From 01a991f56ff50a628d66d2321d011e46c1217d61 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 13:22:33 +0800 Subject: [PATCH 18/31] fix(test): restore identity repo integration imports --- .../repository/user_profile_identity_repo_contract_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go index 69a25fbe..d4f9e8b3 100644 --- a/backend/internal/repository/user_profile_identity_repo_contract_test.go +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -10,6 +10,8 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" ) From 3d29f7c2fac9798bb016c0f3a7a9f13091421ecc Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 13:30:34 +0800 Subject: [PATCH 19/31] fix(auth): invalidate access tokens on session revoke --- backend/internal/handler/auth_handler.go | 2 +- .../handler/auth_session_revocation_test.go | 61 +++++++++++++++++++ .../handler/auth_wechat_oauth_test.go | 12 ---- backend/internal/handler/user_handler.go | 2 +- backend/internal/handler/user_handler_test.go | 12 ++-- backend/internal/service/auth_service.go | 20 ++++++ 6 files changed, 90 insertions(+), 19 deletions(-) create mode 100644 backend/internal/handler/auth_session_revocation_test.go diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index ca3a5a77..dc68a466 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -719,7 +719,7 @@ func (h *AuthHandler) RevokeAllSessions(c *gin.Context) { return } - if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil { + if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil { slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err) response.InternalError(c, "Failed to revoke sessions") return diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go new file mode 100644 index 00000000..1924cb81 --- /dev/null +++ b/backend/internal/handler/auth_session_revocation_test.go @@ -0,0 +1,61 @@ +//go:build unit + +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 29, + Email: "session@example.com", + Username: "session-user", + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 7, + }, + } + refreshTokenCache := &userHandlerRefreshTokenCacheStub{} + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) + handler := &AuthHandler{authService: authService} + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/auth/revoke-all-sessions", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 29}) + + handler.RevokeAllSessions(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, []int64{29}, refreshTokenCache.revokedUserIDs) + require.Equal(t, int64(8), repo.user.TokenVersion) + + var resp struct { + Code int `json:"code"` + Data struct { + Message string `json:"message"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "All sessions have been revoked. Please log in again.", resp.Data.Message) +} diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index b8bd21ce..d303bd42 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -1346,18 +1346,6 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, }, client } -func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { - t.Helper() - - parsed, err := url.Parse(location) - require.NoError(t, err) - - fragment, err := url.ParseQuery(parsed.Fragment) - require.NoError(t, err) - require.Equal(t, errorCode, fragment.Get("error")) - require.Equal(t, errorMessage, fragment.Get("error_message")) -} - type wechatOAuthSettingRepoStub struct { values map[string]string } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 80dcc5ce..867d8c9e 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -259,7 +259,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) { return } if h.authService != nil { - if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil { + if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil { response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 87e168c0..c212603b 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -593,11 +593,12 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure repo := &userHandlerRepoStub{ user: &service.User{ - ID: 23, - Email: "identity@example.com", - Username: "identity-user", - Role: service.RoleUser, - Status: service.StatusActive, + ID: 23, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 4, }, identities: []service.UserAuthIdentityRecord{ { @@ -632,6 +633,7 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure require.Equal(t, http.StatusOK, recorder.Code) require.Equal(t, []int64{23}, refreshTokenCache.revokedUserIDs) + require.Equal(t, int64(5), repo.user.TokenVersion) } func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 6d61894b..efe08644 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -1467,6 +1467,26 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID) } +// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions. +// Access/refresh token verification both depend on TokenVersion, so bumping it provides +// immediate revocation even if refresh-token cache cleanup later fails. +func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error { + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return fmt.Errorf("get user: %w", err) + } + + user.TokenVersion++ + if err := s.userRepo.Update(ctx, user); err != nil { + return fmt.Errorf("update user: %w", err) + } + + if err := s.RevokeAllUserSessions(ctx, userID); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err) + } + return nil +} + // hashToken 计算Token的SHA256哈希 func hashToken(token string) string { hash := sha256.Sum256([]byte(token)) From 36aed35957d74081b26d2897d4bebf89613e6e11 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 14:56:56 +0800 Subject: [PATCH 20/31] fix(auth): harden oauth identity upgrade paths --- .../ent/schema/auth_identity_schema_test.go | 44 +++ backend/ent/schema/user.go | 11 +- backend/internal/config/config.go | 50 ++-- backend/internal/config/config_test.go | 15 + ...tting_handler_auth_source_defaults_test.go | 69 +++++ .../internal/handler/auth_linuxdo_oauth.go | 48 ++-- .../handler/auth_linuxdo_oauth_test.go | 73 ++++- .../handler/auth_oauth_pending_flow.go | 82 +++++- .../handler/auth_oauth_pending_flow_test.go | 15 +- backend/internal/handler/auth_oidc_oauth.go | 30 +- .../internal/handler/auth_oidc_oauth_test.go | 66 ++++- ...ntity_legacy_migration_integration_test.go | 252 +++++++++++++++++ .../internal/repository/migrations_runner.go | 92 +++++- .../migrations_runner_checksum_test.go | 33 +++ .../migrations_runner_extra_test.go | 2 + .../repository/migrations_runner_notx_test.go | 78 ++++++ .../migrations_schema_integration_test.go | 52 ++++ .../repository/user_profile_identity_repo.go | 263 +++++++++++++++--- .../user_profile_identity_repo_unit_test.go | 212 ++++++++++++++ backend/internal/repository/user_repo.go | 89 ++++-- .../user_repo_email_lookup_unit_test.go | 83 +++++- .../internal/service/auth_oauth_email_flow.go | 14 +- .../service/auth_oauth_email_flow_test.go | 74 +++++ .../service/auth_pending_identity_service.go | 194 +++++++++++-- .../auth_pending_identity_service_test.go | 102 +++++++ backend/internal/service/auth_service.go | 22 +- backend/internal/service/setting_service.go | 37 ++- .../setting_service_oidc_config_test.go | 62 ++++- ...nding_auth_and_provider_default_grants.sql | 9 +- ...auth_identity_legacy_external_backfill.sql | 137 ++++++--- ...dentity_legacy_external_safety_reports.sql | 246 +++++++++++++--- ...y_auth_source_grant_on_signup_defaults.sql | 71 ++++- 32 files changed, 2365 insertions(+), 262 deletions(-) create mode 100644 backend/internal/repository/user_profile_identity_repo_unit_test.go diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go index de55dd69..fbb93236 100644 --- a/backend/ent/schema/auth_identity_schema_test.go +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -3,7 +3,9 @@ package schema import ( "testing" + "entgo.io/ent" "entgo.io/ent/entc/load" + "entgo.io/ent/schema/field" "github.com/stretchr/testify/require" ) @@ -74,6 +76,17 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) { userSchema := requireSchema(t, schemas, "User") requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at") + signupSource := requireSchemaField(t, userSchema, "signup_source") + require.Equal(t, field.TypeString, signupSource.Info.Type) + require.True(t, signupSource.Default) + require.Equal(t, "email", signupSource.DefaultValue) + require.Equal(t, 1, signupSource.Validators) + + validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source") + for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} { + require.NoError(t, validator(value)) + } + require.Error(t, validator("github")) } func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { @@ -98,6 +111,37 @@ func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) { } } +func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field { + t.Helper() + + for _, schemaField := range schema.Fields { + if schemaField.Name == name { + return schemaField + } + } + + require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name) + return nil +} + +func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error { + t.Helper() + + for _, entField := range fields { + descriptor := entField.Descriptor() + if descriptor.Name != name { + continue + } + require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name) + validator, ok := descriptor.Validators[0].(func(string) error) + require.True(t, ok, "field %s validator should be func(string) error", name) + return validator + } + + require.Failf(t, "missing field validator", "schema should include field %s", name) + return nil +} + func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) { t.Helper() diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index f307bda8..c0f0bdc1 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -1,6 +1,8 @@ package schema import ( + "fmt" + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -73,7 +75,14 @@ func (User) Fields() []ent.Field { Optional(). Nillable(), field.String("signup_source"). - MaxLen(20). + Validate(func(value string) error { + switch value { + case "email", "linuxdo", "wechat", "oidc": + return nil + default: + return fmt.Errorf("must be one of email, linuxdo, wechat, oidc") + } + }). Default("email"), field.Time("last_login_at"). Optional(). diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d47eadd4..87263db0 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -211,25 +211,27 @@ type WeChatConnectConfig struct { } type OIDCConnectConfig struct { - Enabled bool `mapstructure:"enabled"` - ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 - ClientID string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` - IssuerURL string `mapstructure:"issuer_url"` - DiscoveryURL string `mapstructure:"discovery_url"` - AuthorizeURL string `mapstructure:"authorize_url"` - TokenURL string `mapstructure:"token_url"` - UserInfoURL string `mapstructure:"userinfo_url"` - JWKSURL string `mapstructure:"jwks_url"` - Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" - RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) - FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) - TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none - UsePKCE bool `mapstructure:"use_pkce"` - ValidateIDToken bool `mapstructure:"validate_id_token"` - AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256" - ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 - RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false + Enabled bool `mapstructure:"enabled"` + ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + IssuerURL string `mapstructure:"issuer_url"` + DiscoveryURL string `mapstructure:"discovery_url"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + JWKSURL string `mapstructure:"jwks_url"` + Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" + RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + ValidateIDToken bool `mapstructure:"validate_id_token"` + UsePKCEExplicit bool `mapstructure:"-" yaml:"-"` + ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"` + AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256" + ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 + RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 // 为空时,服务端会尝试一组常见字段名。 @@ -329,6 +331,14 @@ func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool { return !hasNewEnv } +func hasExplicitConfigOrEnv(configKey, envKey string) bool { + if viper.InConfig(configKey) { + return true + } + _, ok := os.LookupEnv(envKey) + return ok +} + func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) { if cfg == nil { return @@ -1262,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath) cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath) cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath) + cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE") + cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN") cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 8b59ef5f..6ba86aa1 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -254,6 +254,21 @@ func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) { require.NoError(t, err) require.True(t, cfg.OIDC.UsePKCE) require.True(t, cfg.OIDC.ValidateIDToken) + require.False(t, cfg.OIDC.UsePKCEExplicit) + require.False(t, cfg.OIDC.ValidateIDTokenExplicit) +} + +func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("OIDC_CONNECT_USE_PKCE", "false") + t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false") + + cfg, err := Load() + require.NoError(t, err) + require.False(t, cfg.OIDC.UsePKCE) + require.False(t, cfg.OIDC.ValidateIDToken) + require.True(t, cfg.OIDC.UsePKCEExplicit) + require.True(t, cfg.OIDC.ValidateIDTokenExplicit) } func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index 8045d0c9..9a33a93a 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -335,6 +335,75 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla require.Equal(t, false, data["oidc_connect_validate_id_token"]) } +func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyOIDCConnectEnabled: "true", + service.SettingKeyOIDCConnectProviderName: "OIDC", + service.SettingKeyOIDCConnectClientID: "oidc-client", + service.SettingKeyOIDCConnectClientSecret: "oidc-secret", + service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com", + service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth", + service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token", + service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo", + service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks", + service.SettingKeyOIDCConnectScopes: "openid email profile", + service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256", + service.SettingKeyOIDCConnectClockSkewSeconds: "120", + service.SettingKeyOIDCConnectRequireEmailVerified: "false", + service.SettingKeyOIDCConnectUserInfoEmailPath: "", + service.SettingKeyOIDCConnectUserInfoIDPath: "", + service.SettingKeyOIDCConnectUserInfoUsernamePath: "", + }, + } + svc := service.NewSettingService(repo, &config.Config{ + Default: config.DefaultConfig{UserConcurrency: 5}, + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + JWKSURL: "https://issuer.example.com/jwks", + Scopes: "openid email profile", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + }) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "oidc_connect_enabled": true, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE]) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken]) +} + func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { gin.SetMode(gin.TestMode) repo := &settingHandlerRepoStub{ diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index a7e77c09..2ef05963 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -355,15 +355,20 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri } userEntity, err := client.User.Query(). - Where(dbuser.EmailEqualFold(email)). - Only(ctx) + Where(userNormalizedEmailPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) if err != nil { - if dbent.IsNotFound(err) { - return nil, nil - } return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) } - return userEntity, nil + switch len(userEntity) { + case 0: + return nil, nil + case 1: + return userEntity[0], nil + default: + return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users") + } } func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( @@ -411,9 +416,15 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( completionResponse["choice_reason"] = "force_email_on_signup" } + var targetUserID *int64 + if compatEmailUser != nil && compatEmailUser.ID > 0 { + targetUserID = &compatEmailUser.ID + } + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ Intent: oauthIntentLogin, Identity: identity, + TargetUserID: targetUserID, ResolvedEmail: resolvedChoiceEmail, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -490,9 +501,13 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) - if err != nil { - response.ErrorFrom(c, err) + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil { + respondPendingOAuthBindingApplyError(c, err) return } decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ @@ -503,17 +518,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) - return - } - h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) - if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { - clearOAuthPendingSessionCookie(c, secureCookie) - clearOAuthPendingBrowserCookie(c, secureCookie) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { response.ErrorFrom(c, err) return } + if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index d535c178..8b01ab41 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -508,7 +508,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test ctx := context.Background() existingUser, err := client.User.Create(). - SetEmail("legacy@example.com"). + SetEmail(" Legacy@Example.com "). SetUsername("legacy-user"). SetPasswordHash("hash"). SetRole(service.RoleUser). @@ -539,16 +539,17 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test Only(ctx) require.NoError(t, err) require.Equal(t, oauthIntentLogin, session.Intent) - require.Nil(t, session.TargetUserID) - require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail) require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) require.True(t, ok) require.Equal(t, "/dashboard", completion["redirect"]) require.Equal(t, oauthPendingChoiceStep, completion["step"]) - require.Equal(t, existingUser.Email, completion["email"]) - require.Equal(t, existingUser.Email, completion["existing_account_email"]) + require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"]) + require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"]) require.Equal(t, true, completion["existing_account_bindable"]) require.Equal(t, "compat_email_match", completion["choice_reason"]) _, hasAccessToken := completion["access_token"] @@ -943,6 +944,68 @@ func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *te require.False(t, decision.AdoptAvatar) } +func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(existingOwner.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-conflict-subject"). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-conflict-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-conflict-subject"). + SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-conflict-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusConflict, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"]) + + userCount, err := client.User.Query(). + Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")). + Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { t.Helper() handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 7be01e74..ab854d24 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -519,7 +519,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { email := strings.TrimSpace(strings.ToLower(req.Email)) if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil { - session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) if err != nil { response.ErrorFrom(c, err) return @@ -704,6 +704,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email return matches[0], nil } +func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error { + if client == nil || session == nil { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil + } + return err + } + if identity == nil || identity.UserID <= 0 { + return nil + } + + activeOwner, err := findActiveUserByID(ctx, client, identity.UserID) + if err != nil { + return err + } + if activeOwner != nil { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + return nil +} + func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { if session == nil { return nil @@ -1206,6 +1238,38 @@ func consumePendingOAuthBrowserSessionTx( return nil } +func applyPendingOAuthAdoptionAndConsumeSession( + ctx context.Context, + client *dbent.Client, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + userID int64, +) error { + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + if session == nil || userID <= 0 { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil { + return err + } + if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil { + return err + } + return tx.Commit() +} + func applyPendingOAuthAdoption( ctx context.Context, client *dbent.Client, @@ -1448,16 +1512,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState( c *gin.Context, client *dbent.Client, session *dbent.PendingAuthSession, + targetUser *dbent.User, email string, ) (*dbent.PendingAuthSession, error) { completionResponse := pendingOAuthChoiceCompletionResponse(session, email) + var targetUserID *int64 + if targetUser != nil && targetUser.ID > 0 { + targetUserID = &targetUser.ID + } session, err := updatePendingOAuthSessionProgress( c.Request.Context(), client, session, strings.TrimSpace(session.Intent), email, - nil, + targetUserID, completionResponse, ) if err != nil { @@ -1601,7 +1670,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) } } if existingUser != nil { - session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) if err != nil { response.ErrorFrom(c, err) return @@ -1624,7 +1693,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) ) if err != nil { if errors.Is(err, service.ErrEmailExists) { - session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) + existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email) + if lookupErr != nil { + response.ErrorFrom(c, lookupErr) + return + } + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index bc8fe7eb..9f9e497b 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -1045,7 +1045,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t * handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - _, err := client.User.Create(). + existingUser, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1099,7 +1099,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t * storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) require.Equal(t, oauthIntentLogin, storedSession.Intent) - require.Nil(t, storedSession.TargetUserID) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) require.Nil(t, storedSession.ConsumedAt) @@ -1118,7 +1119,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - _, err := client.User.Create(). + existingUser, err := client.User.Create(). SetEmail(" Owner@Example.com "). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1164,7 +1165,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) - require.Nil(t, storedSession.TargetUserID) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) } @@ -1172,7 +1174,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - _, err := client.User.Create(). + existingUser, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1220,7 +1222,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) require.Equal(t, oauthIntentLogin, storedSession.Intent) - require.Nil(t, storedSession.TargetUserID) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) } diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 3c67e421..0ac8871b 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -563,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession( if compatEmailUser != nil { resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) } + var targetUserID *int64 + if compatEmailUser != nil && compatEmailUser.ID > 0 { + targetUserID = &compatEmailUser.ID + } return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ Intent: oauthIntentLogin, Identity: identity, + TargetUserID: targetUserID, ResolvedEmail: resolvedChoiceEmail, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -643,9 +648,13 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) - if err != nil { - response.ErrorFrom(c, err) + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil { + respondPendingOAuthBindingApplyError(c, err) return } decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ @@ -656,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) - return - } - h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) - if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { - clearOAuthPendingSessionCookie(c, secureCookie) - clearOAuthPendingBrowserCookie(c, secureCookie) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { response.ErrorFrom(c, err) return } + if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index c2855dc9..3216d51e 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -438,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing Only(ctx) require.NoError(t, err) require.Equal(t, oauthIntentLogin, session.Intent) - require.Nil(t, session.TargetUserID) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) require.Equal(t, existingUser.Email, session.ResolvedEmail) require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) @@ -862,6 +863,69 @@ func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testi require.False(t, decision.AdoptAvatar) } +func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(existingOwner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-conflict-subject"). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-conflict-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-conflict-subject"). + SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-conflict-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusConflict, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"]) + + userCount, err := client.User.Query(). + Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")). + Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + type oidcProviderFixture struct { Subject string PreferredUsername string diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go index 41b64de7..e64934c5 100644 --- a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go +++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go @@ -576,6 +576,258 @@ FROM auth_identity_migration_reports require.Equal(t, beforeCount, afterCount) } +func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + var linuxDoFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoFirstUserID)) + + var linuxDoSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoSecondUserID)) + + var wechatFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatFirstUserID)) + + var wechatSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatSecondUserID)) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}') +RETURNING id +`, linuxDoFirstUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}') +RETURNING id +`, linuxDoSecondUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}') +RETURNING id +`, wechatFirstUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}') +RETURNING id +`, wechatSecondUserID).Scan(new(int64))) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var linuxDoIdentityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-ambiguous-subject' +`).Scan(&linuxDoIdentityCount)) + require.Zero(t, linuxDoIdentityCount) + + var wechatIdentityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND provider_subject = 'union-ambiguous-subject' +`).Scan(&wechatIdentityCount)) + require.Zero(t, wechatIdentityCount) + + var wechatChannelCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_channels +WHERE provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND channel = 'oa' + AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b') +`).Scan(&wechatChannelCount)) + require.Zero(t, wechatChannelCount) +} + +func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migration115SQL, err := os.ReadFile(migration115Path) + require.NoError(t, err) + + migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migration116SQL, err := os.ReadFile(migration116Path) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + var linuxDoFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoFirstUserID)) + + var linuxDoSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoSecondUserID)) + + var wechatFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatFirstUserID)) + + var wechatSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatSecondUserID)) + + var linuxDoFirstLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}') +RETURNING id +`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID)) + + var linuxDoSecondLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}') +RETURNING id +`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID)) + + var wechatFirstLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}') +RETURNING id +`, wechatFirstUserID).Scan(&wechatFirstLegacyID)) + + var wechatSecondLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}') +RETURNING id +`, wechatSecondUserID).Scan(&wechatSecondLegacyID)) + + _, err = tx.ExecContext(ctx, string(migration115SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration116SQL)) + require.NoError(t, err) + + var identityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject') + OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject') +`).Scan(&identityCount)) + require.Zero(t, identityCount) + + var conflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key IN ($1, $2, $3, $4) +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount)) + require.Equal(t, 4, conflictReportCount) + + var winnerAttributedReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key IN ($1, $2, $3, $4) + AND details ->> 'existing_identity_id' IS NOT NULL +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount)) + require.Zero(t, winnerAttributedReportCount) +} + func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) { tx := testTx(t) ctx := context.Background() diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 662a3972..f5798486 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -51,6 +51,8 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsLockRetryInterval = 500 * time.Millisecond const nonTransactionalMigrationSuffix = "_notx.sql" +const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql" +const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique" type migrationChecksumCompatibilityRule struct { fileChecksum string @@ -65,9 +67,11 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"), "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"), "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), + "115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"), + "116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"), "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"), "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"), - "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"), + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"), "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"), } @@ -195,6 +199,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { } if nonTx { + if err := prepareNonTransactionalMigration(ctx, db, name); err != nil { + return fmt.Errorf("prepare migration %s: %w", name, err) + } + // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。 // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。 statements := splitSQLStatements(content) @@ -244,6 +252,88 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { return nil } +func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error { + switch name { + case paymentOrdersOutTradeNoUniqueMigration: + return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db) + default: + return nil + } +} + +func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error { + duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db) + if err != nil { + return fmt.Errorf("precheck duplicate out_trade_no: %w", err) + } + if len(duplicates) > 0 { + return fmt.Errorf( + "duplicate out_trade_no values block %s; remediate duplicates before retrying: %s", + paymentOrdersOutTradeNoUniqueMigration, + strings.Join(duplicates, ", "), + ) + } + + invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex) + if err != nil { + return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err) + } + if !invalid { + return nil + } + + if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil { + return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err) + } + return nil +} + +func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) { + rows, err := db.QueryContext(ctx, ` + SELECT out_trade_no, COUNT(*) AS duplicate_count + FROM payment_orders + WHERE out_trade_no <> '' + GROUP BY out_trade_no + HAVING COUNT(*) > 1 + ORDER BY duplicate_count DESC, out_trade_no + LIMIT 5 + `) + if err != nil { + return nil, err + } + defer rows.Close() + + duplicates := make([]string, 0, 5) + for rows.Next() { + var outTradeNo string + var duplicateCount int + if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil { + return nil, err + } + duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount)) + } + if err := rows.Err(); err != nil { + return nil, err + } + return duplicates, nil +} + +func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) { + var invalid bool + err := db.QueryRowContext(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM pg_class idx + JOIN pg_namespace ns ON ns.oid = idx.relnamespace + JOIN pg_index i ON i.indexrelid = idx.oid + WHERE ns.nspname = 'public' + AND idx.relname = $1 + AND NOT i.indisvalid + ) + `, indexName).Scan(&invalid) + return invalid, err +} + func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error { hasLegacy, err := tableExists(ctx, db, "schema_migrations") if err != nil { diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go index dc241a75..57647093 100644 --- a/backend/internal/repository/migrations_runner_checksum_test.go +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -70,6 +70,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { require.True(t, ok) }) + t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "115_auth_identity_legacy_external_backfill.sql", + "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f", + "022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", + ) + require.True(t, ok) + }) + + t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "116_auth_identity_legacy_external_safety_reports.sql", + "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877", + "07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", + ) + require.True(t, ok) + }) + t.Run("119历史checksum可兼容占位文件", func(t *testing.T) { ok := isMigrationChecksumCompatible( "119_enforce_payment_orders_out_trade_no_unique.sql", @@ -79,6 +97,21 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { require.True(t, ok) }) + t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) { + for _, dbChecksum := range []string{ + "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", + "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", + "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a", + } { + ok := isMigrationChecksumCompatible( + "120_enforce_payment_orders_out_trade_no_unique_notx.sql", + dbChecksum, + "34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", + ) + require.True(t, ok) + } + }) + t.Run("119未知checksum不兼容", func(t *testing.T) { ok := isMigrationChecksumCompatible( "119_enforce_payment_orders_out_trade_no_unique.sql", diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go index af1adc50..a8bc15bc 100644 --- a/backend/internal/repository/migrations_runner_extra_test.go +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -96,6 +96,8 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) { for _, name := range []string{ + "115_auth_identity_legacy_external_backfill.sql", + "116_auth_identity_legacy_external_safety_reports.sql", "118_wechat_dual_mode_and_auth_source_defaults.sql", "120_enforce_payment_orders_out_trade_no_unique_notx.sql", "123_fix_legacy_auth_source_grant_on_signup_defaults.sql", diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go index db1183cd..b7cb396c 100644 --- a/backend/internal/repository/migrations_runner_notx_test.go +++ b/backend/internal/repository/migrations_runner_notx_test.go @@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b); require.NoError(t, mock.ExpectationsWereMet()) } +func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders"). + WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{ + Data: []byte(` +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; + +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no; +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "duplicate out_trade_no") + require.Contains(t, err.Error(), "dup-out-trade-no") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders"). + WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"})) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("paymentorder_out_trade_no_unique"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{ + Data: []byte(` +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; + +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no; +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index ac4dea18..eeee5c23 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -93,6 +93,19 @@ func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) tx := testTx(t) requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false) + requireColumn(t, tx, "users", "signup_source", "character varying", 20, false) + requireColumnDefaultContains(t, tx, "users", "signup_source", "email") + requireConstraintDefinitionContains( + t, + tx, + "users", + "users_signup_source_check", + "signup_source", + "'email'", + "'linuxdo'", + "'wechat'", + "'oidc'", + ) requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE") requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE") @@ -195,6 +208,45 @@ LIMIT 1 require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable) } +func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) { + t.Helper() + + var def string + err := tx.QueryRowContext(context.Background(), ` +SELECT pg_get_constraintdef(c.oid) +FROM pg_constraint c +JOIN pg_class tbl ON tbl.oid = c.conrelid +JOIN pg_namespace ns ON ns.oid = tbl.relnamespace +WHERE ns.nspname = 'public' + AND tbl.relname = $1 + AND c.conname = $2 +`, table, constraint).Scan(&def) + require.NoError(t, err, "query constraint definition for %s.%s", table, constraint) + + for _, fragment := range fragments { + require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment) + } +} + +func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) { + t.Helper() + + var columnDefault sql.NullString + err := tx.QueryRowContext(context.Background(), ` +SELECT column_default +FROM information_schema.columns +WHERE table_schema = 'public' + AND table_name = $1 + AND column_name = $2 +`, table, column).Scan(&columnDefault) + require.NoError(t, err, "query column_default for %s.%s", table, column) + require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column) + + for _, fragment := range fragments { + require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment) + } +} + func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { t.Helper() diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go index 87094ad7..b2b03746 100644 --- a/backend/internal/repository/user_profile_identity_repo.go +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -4,11 +4,15 @@ import ( "context" "database/sql" "fmt" + "hash/fnv" "reflect" + "sort" "strings" + "sync" "time" "unsafe" + "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" @@ -120,6 +124,113 @@ type sqlQueryExecutor interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } +var repositoryScopedKeyLocks = newScopedKeyLockRegistry() + +type scopedKeyLockRegistry struct { + mu sync.Mutex + locks map[string]*scopedKeyLockEntry +} + +type scopedKeyLockEntry struct { + mu sync.Mutex + refs int +} + +func newScopedKeyLockRegistry() *scopedKeyLockRegistry { + return &scopedKeyLockRegistry{ + locks: make(map[string]*scopedKeyLockEntry), + } +} + +func (r *scopedKeyLockRegistry) lock(keys ...string) func() { + normalized := normalizeLockKeys(keys...) + if len(normalized) == 0 { + return func() {} + } + + entries := make([]*scopedKeyLockEntry, 0, len(normalized)) + r.mu.Lock() + for _, key := range normalized { + entry := r.locks[key] + if entry == nil { + entry = &scopedKeyLockEntry{} + r.locks[key] = entry + } + entry.refs++ + entries = append(entries, entry) + } + r.mu.Unlock() + + for _, entry := range entries { + entry.mu.Lock() + } + + return func() { + for i := len(entries) - 1; i >= 0; i-- { + entries[i].mu.Unlock() + } + + r.mu.Lock() + defer r.mu.Unlock() + for idx, key := range normalized { + entry := entries[idx] + entry.refs-- + if entry.refs == 0 { + delete(r.locks, key) + } + } + } +} + +func normalizeLockKeys(keys ...string) []string { + if len(keys) == 0 { + return nil + } + + deduped := make(map[string]struct{}, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + continue + } + deduped[trimmed] = struct{}{} + } + if len(deduped) == 0 { + return nil + } + + normalized := make([]string, 0, len(deduped)) + for key := range deduped { + normalized = append(normalized, key) + } + sort.Strings(normalized) + return normalized +} + +func advisoryLockHash(key string) int64 { + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(key)) + return int64(hasher.Sum64()) +} + +func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) { + release := repositoryScopedKeyLocks.lock(keys...) + normalized := normalizeLockKeys(keys...) + if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres { + return release, nil + } + + for _, key := range normalized { + rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key)) + if err != nil { + release() + return nil, err + } + _ = rows.Close() + } + return release, nil +} + func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { if dbent.TxFromContext(ctx) != nil { return fn(ctx) @@ -329,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA return err } } else { + targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey) update := client.AuthIdentity.UpdateOneID(identity.ID) + if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) { + update = update.SetProviderKey(targetProviderKey) + } if input.Metadata != nil { update = update.SetMetadata(copyMetadata(input.Metadata)) } @@ -378,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA return err } } else { + targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey) update := client.AuthIdentityChannel.UpdateOneID(channel.ID). SetIdentityID(identity.ID) + if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) { + update = update.SetProviderKey(targetProviderKey) + } if input.ChannelMetadata != nil { update = update.SetMetadata(copyMetadata(input.ChannelMetadata)) } @@ -418,13 +537,52 @@ func compatibleIdentityProviderKeys(providerType, providerKey string) []string { return keys } +func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + existingKey = strings.TrimSpace(existingKey) + requestedKey = strings.TrimSpace(requestedKey) + if providerType != "wechat" { + if requestedKey != "" { + return requestedKey + } + return existingKey + } + if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") { + return "wechat-main" + } + if requestedKey != "" { + return requestedKey + } + return existingKey +} + +func compatibleIdentityProviderKeyRank(providerType, providerKey string) int { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerType != "wechat" { + return 0 + } + switch { + case strings.EqualFold(providerKey, "wechat-main"): + return 0 + case strings.EqualFold(providerKey, "wechat"): + return 2 + default: + return 1 + } +} + func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity { + var selected *dbent.AuthIdentity for _, record := range records { - if record.UserID == userID { - return record + if record.UserID != userID { + continue + } + if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record } } - return nil + return selected } func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool { @@ -437,12 +595,16 @@ func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) } func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel { + var selected *dbent.AuthIdentityChannel for _, record := range records { - if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID { - return record + if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID { + continue + } + if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record } } - return nil + return selected } func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool { @@ -479,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, } func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { - client := clientFromContext(ctx, r.client) - if input.IdentityID != nil && *input.IdentityID > 0 { - if _, err := client.IdentityAdoptionDecision.Update(). - Where( - identityadoptiondecision.IdentityIDEQ(*input.IdentityID), - dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { - col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) - s.Where(entsql.Or( - entsql.IsNull(col), - entsql.NEQ(col, input.PendingAuthSessionID), - )) - }), - ). - ClearIdentityID(). - Save(ctx); err != nil { - return nil, err + var result *dbent.IdentityAdoptionDecision + err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + releaseLocks, err := lockRepositoryScopedKeys( + txCtx, + client, + txAwareSQLExecutor(txCtx, r.sql, r.client), + identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)..., + ) + if err != nil { + return err + } + defer releaseLocks() + + if input.IdentityID != nil && *input.IdentityID > 0 { + if _, err := client.IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(*input.IdentityID), + dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { + col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.NEQ(col, input.PendingAuthSessionID), + )) + }), + ). + ClearIdentityID(). + Save(txCtx); err != nil { + return err + } } - } - current, err := client.IdentityAdoptionDecision.Query(). - Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { - return nil, err - } - now := time.Now().UTC() - if current == nil { create := client.IdentityAdoptionDecision.Create(). SetPendingAuthSessionID(input.PendingAuthSessionID). SetAdoptDisplayName(input.AdoptDisplayName). SetAdoptAvatar(input.AdoptAvatar). - SetDecidedAt(now) - if input.IdentityID != nil { + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil && *input.IdentityID > 0 { create = create.SetIdentityID(*input.IdentityID) } - return create.Save(ctx) - } - update := client.IdentityAdoptionDecision.UpdateOneID(current.ID). - SetAdoptDisplayName(input.AdoptDisplayName). - SetAdoptAvatar(input.AdoptAvatar) - if input.IdentityID != nil { - update = update.SetIdentityID(*input.IdentityID) + decisionID, err := create. + OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID). + UpdateNewValues(). + ID(txCtx) + if err != nil { + return err + } + + result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID) + return err + }) + if err != nil { + return nil, err } - return update.Save(ctx) + return result, nil +} + +func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string { + keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)} + if identityID != nil && *identityID > 0 { + keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID)) + } + return keys } func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) { diff --git a/backend/internal/repository/user_profile_identity_repo_unit_test.go b/backend/internal/repository/user_profile_identity_repo_unit_test.go new file mode 100644 index 00000000..689f32f9 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo_unit_test.go @@ -0,0 +1,212 @@ +package repository + +import ( + "context" + "sync" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + user := &service.User{ + Email: "wechat-legacy@example.com", + Username: "wechat-legacy", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, user)) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetProviderSubject("union-legacy-123"). + SetMetadata(map[string]any{"source": "legacy-alias"}). + Save(ctx) + require.NoError(t, err) + + legacyChannel, err := client.AuthIdentityChannel.Create(). + SetIdentityID(legacyIdentity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("oa"). + SetChannelAppID("wx-app-legacy"). + SetChannelSubject("openid-legacy-123"). + SetMetadata(map[string]any{"scene": "legacy-alias"}). + Save(ctx) + require.NoError(t, err) + + bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-legacy-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + Channel: "oa", + ChannelAppID: "wx-app-legacy", + ChannelSubject: "openid-legacy-123", + }, + Metadata: map[string]any{"source": "canonical-bind"}, + ChannelMetadata: map[string]any{"scene": "canonical-bind"}, + }) + require.NoError(t, err) + require.NotNil(t, bound) + require.NotNil(t, bound.Identity) + require.NotNil(t, bound.Channel) + require.Equal(t, legacyIdentity.ID, bound.Identity.ID) + require.Equal(t, legacyChannel.ID, bound.Channel.ID) + require.Equal(t, "wechat-main", bound.Identity.ProviderKey) + require.Equal(t, "wechat-main", bound.Channel.ProviderKey) + + reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey) + require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"]) + + reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", reloadedChannel.ProviderKey) + require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderSubjectEQ("union-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + + channelCount, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ChannelEQ("oa"), + authidentitychannel.ChannelAppIDEQ("wx-app-legacy"), + authidentitychannel.ChannelSubjectEQ("openid-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, channelCount) +} + +func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + user := &service.User{ + Email: "repo-adoption@example.com", + Username: "repo-adoption", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, user)) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-repo-adoption"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-repo-adoption"). + SetIntent("bind_current_user"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-repo-adoption"). + SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)). + SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}). + SetLocalFlowState(map[string]any{"step": "pending"}). + Save(ctx) + require.NoError(t, err) + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type adoptionResult struct { + decision *dbent.IdentityAdoptionDecision + err error + } + + input := IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + } + + results := make(chan adoptionResult, 2) + go func() { + decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + <-firstCreateStarted + + go func() { + decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + second := <-results + + require.NoError(t, first.err) + require.NoError(t, second.err) + require.NotNil(t, first.decision) + require.NotNil(t, second.decision) + require.Equal(t, first.decision.ID, second.decision.ID) + + count, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, count) + + loaded, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, loaded.IdentityID) + require.Equal(t, identity.ID, *loaded.IdentityID) + require.True(t, loaded.AdoptDisplayName) + require.True(t, loaded.AdoptAvatar) +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 68e51eeb..3d526e7b 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error if userIn == nil { return nil } - if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil { - return err - } // 统一使用 ent 的事务:保证用户与允许分组的更新原子化, // 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。 @@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error } var txClient *dbent.Client + txCtx := ctx if err == nil { defer func() { _ = tx.Rollback() }() txClient = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) } else { // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 if existingTx := dbent.TxFromContext(ctx); existingTx != nil { @@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error } } + releaseEmailLock, err := lockRepositoryScopedKeys( + txCtx, + txClient, + txAwareSQLExecutor(txCtx, r.sql, r.client), + normalizedEmailUniquenessLockKey(userIn.Email), + ) + if err != nil { + return err + } + defer releaseEmailLock() + + if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil { + return err + } + created, err := txClient.User.Create(). SetEmail(userIn.Email). SetUsername(userIn.Username). @@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). SetNillableLastLoginAt(userIn.LastLoginAt). SetNillableLastActiveAt(userIn.LastActiveAt). - Save(ctx) + Save(txCtx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) } - if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil { + if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil { return err } - if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil { + if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil { return err } @@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error if userIn == nil { return nil } - if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil { - return err - } // 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。 tx, err := r.client.Tx(ctx) @@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error } var txClient *dbent.Client + txCtx := ctx if err == nil { defer func() { _ = tx.Rollback() }() txClient = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) } else { // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 if existingTx := dbent.TxFromContext(ctx); existingTx != nil { @@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error txClient = r.client } } - existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID) + + releaseEmailLock, err := lockRepositoryScopedKeys( + txCtx, + txClient, + txAwareSQLExecutor(txCtx, r.sql, r.client), + normalizedEmailUniquenessLockKey(userIn.Email), + ) + if err != nil { + return err + } + defer releaseEmailLock() + + if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil { + return err + } + + existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, nil) } @@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error if userIn.BalanceNotifyThreshold == nil { updateOp = updateOp.ClearBalanceNotifyThreshold() } - updated, err := updateOp.Save(ctx) + updated, err := updateOp.Save(txCtx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) } - if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil { + if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil { return err } - if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil { + if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil { return err } @@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, } func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error { - matches, err := r.client.User.Query(). + return ensureNormalizedEmailAvailableWithClient(ctx, clientFromContext(ctx, r.client), userID, email) +} + +func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error { + client = clientFromContext(ctx, client) + if client == nil { + return nil + } + + matches, err := client.User.Query(). Where(userEmailLookupPredicate(email)). All(ctx) if err != nil { @@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use } func userEmailLookupPredicate(email string) predicate.User { - normalized := strings.ToLower(strings.TrimSpace(email)) + normalized := normalizeEmailLookupValue(email) if normalized == "" { return dbuser.EmailEQ(email) } @@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User { }) } +func normalizeEmailLookupValue(email string) string { + return strings.ToLower(strings.TrimSpace(email)) +} + +func normalizedEmailUniquenessLockKey(email string) string { + normalized := normalizeEmailLookupValue(email) + if normalized == "" { + return "" + } + return "users:normalized-email:" + normalized +} + func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { client := clientFromContext(ctx, r.client) err := client.UserAllowedGroup.Create(). @@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { } func userSignupSourceOrDefault(signupSource string) string { - signupSource = strings.TrimSpace(signupSource) - if signupSource == "" { + switch strings.TrimSpace(strings.ToLower(signupSource)) { + case "", "email": + return "email" + case "linuxdo", "wechat", "oidc": + return strings.TrimSpace(strings.ToLower(signupSource)) + default: return "email" } - return signupSource } // marshalExtraEmails serializes notify email entries to JSON for storage. diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go index b2b02ef5..2ef9d761 100644 --- a/backend/internal/repository/user_repo_email_lookup_unit_test.go +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -3,7 +3,10 @@ package repository import ( "context" "database/sql" + "fmt" + "sync" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/enttest" @@ -18,9 +21,10 @@ import ( func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) { t.Helper() - db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared") + db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name())) require.NoError(t, err) t.Cleanup(func() { _ = db.Close() }) + db.SetMaxOpenConns(10) _, err = db.Exec("PRAGMA foreign_keys = ON") require.NoError(t, err) @@ -144,3 +148,80 @@ func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "normalized email lookup matched multiple users") } + +func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.User.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type createResult struct { + err error + } + + results := make(chan createResult, 2) + go func() { + results <- createResult{err: repo.Create(ctx, &service.User{ + Email: " Race@Example.com ", + Username: "race-user-1", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + })} + }() + + <-firstCreateStarted + + go func() { + results <- createResult{err: repo.Create(ctx, &service.User{ + Email: "race@example.com", + Username: "race-user-2", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + })} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + second := <-results + + errors := []error{first.err, second.err} + successes := 0 + conflicts := 0 + for _, err := range errors { + switch { + case err == nil: + successes++ + case err == service.ErrEmailExists: + conflicts++ + default: + t.Fatalf("unexpected create error: %v", err) + } + } + require.Equal(t, 1, successes) + require.Equal(t, 1, conflicts) + + count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, count) +} diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index ea558ae2..a18cf39c 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -14,10 +14,14 @@ import ( func normalizeOAuthSignupSource(signupSource string) string { signupSource = strings.TrimSpace(strings.ToLower(signupSource)) - if signupSource == "" { + switch signupSource { + case "", "email": + return "email" + case "linuxdo", "wechat", "oidc": + return signupSource + default: return "email" } - return signupSource } // SendPendingOAuthVerifyCode sends a local verification code for pending OAuth @@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( return nil, nil, fmt.Errorf("hash password: %w", err) } - signupSource = strings.TrimSpace(strings.ToLower(signupSource)) - if signupSource == "" { - signupSource = "email" - } + signupSource = normalizeOAuthSignupSource(signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) user := &User{ @@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, Status: StatusActive, + SignupSource: signupSource, } if err := s.userRepo.Create(ctx, user); err != nil { diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go index a77dda72..e3fb2f85 100644 --- a/backend/internal/service/auth_oauth_email_flow_test.go +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai require.Empty(t, redeemRepo.updateCalls) } +func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) { + userRepo := &userRepoStub{nextID: 42} + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fresh@example.com", + "secret-123", + "246810", + "", + " OIDC ", + ) + + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Len(t, userRepo.created, 1) + require.Equal(t, "oidc", userRepo.created[0].SignupSource) +} + +func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) { + userRepo := &userRepoStub{nextID: 43} + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fallback@example.com", + "secret-123", + "246810", + "", + "github", + ) + + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Len(t, userRepo.created, 1) + require.Equal(t, "email", userRepo.created[0].SignupSource) +} + func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) { userRepo := &userRepoStub{} redeemRepo := &redeemCodeRepoStub{ diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go index cc0522ab..6e69c121 100644 --- a/backend/internal/service/auth_pending_identity_service.go +++ b/backend/internal/service/auth_pending_identity_service.go @@ -5,10 +5,15 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" + "errors" "fmt" + "hash/fnv" + "sort" "strings" + "sync" "time" + "entgo.io/ent/dialect" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" @@ -75,6 +80,122 @@ type AuthPendingIdentityService struct { entClient *dbent.Client } +var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry() + +type authPendingIdentityScopedKeyLockRegistry struct { + mu sync.Mutex + locks map[string]*authPendingIdentityScopedKeyLockEntry +} + +type authPendingIdentityScopedKeyLockEntry struct { + mu sync.Mutex + refs int +} + +func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry { + return &authPendingIdentityScopedKeyLockRegistry{ + locks: make(map[string]*authPendingIdentityScopedKeyLockEntry), + } +} + +func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() { + normalized := normalizeAuthPendingIdentityLockKeys(keys...) + if len(normalized) == 0 { + return func() {} + } + + entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized)) + r.mu.Lock() + for _, key := range normalized { + entry := r.locks[key] + if entry == nil { + entry = &authPendingIdentityScopedKeyLockEntry{} + r.locks[key] = entry + } + entry.refs++ + entries = append(entries, entry) + } + r.mu.Unlock() + + for _, entry := range entries { + entry.mu.Lock() + } + + return func() { + for i := len(entries) - 1; i >= 0; i-- { + entries[i].mu.Unlock() + } + + r.mu.Lock() + defer r.mu.Unlock() + for idx, key := range normalized { + entry := entries[idx] + entry.refs-- + if entry.refs == 0 { + delete(r.locks, key) + } + } + } +} + +func normalizeAuthPendingIdentityLockKeys(keys ...string) []string { + if len(keys) == 0 { + return nil + } + + deduped := make(map[string]struct{}, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + continue + } + deduped[trimmed] = struct{}{} + } + if len(deduped) == 0 { + return nil + } + + normalized := make([]string, 0, len(deduped)) + for key := range deduped { + normalized = append(normalized, key) + } + sort.Strings(normalized) + return normalized +} + +func authPendingIdentityAdvisoryLockHash(key string) int64 { + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(key)) + return int64(hasher.Sum64()) +} + +func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) { + release := authPendingIdentityScopedKeyLocks.lock(keys...) + normalized := normalizeAuthPendingIdentityLockKeys(keys...) + if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres { + return release, nil + } + + for _, key := range normalized { + var rows entsql.Rows + if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil { + release() + return nil, err + } + _ = rows.Close() + } + + return release, nil +} + +func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string { + keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)} + if identityID != nil && *identityID > 0 { + keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID)) + } + return keys +} + func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { return &AuthPendingIdentityService{entClient: entClient} } @@ -324,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, return nil, fmt.Errorf("pending auth ent client is not configured") } + tx, err := s.entClient.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return nil, err + } + + client := s.entClient + txCtx := ctx + if err == nil { + defer func() { _ = tx.Rollback() }() + client = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) + } else if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + client = existingTx.Client() + } + + releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...) + if err != nil { + return nil, err + } + defer releaseLocks() + if input.IdentityID != nil && *input.IdentityID > 0 { - if _, err := s.entClient.IdentityAdoptionDecision.Update(). + if _, err := client.IdentityAdoptionDecision.Update(). Where( identityadoptiondecision.IdentityIDEQ(*input.IdentityID), dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { @@ -337,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, }), ). ClearIdentityID(). - Save(ctx); err != nil { + Save(txCtx); err != nil { return nil, err } } - existing, err := s.entClient.IdentityAdoptionDecision.Query(). - Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { - return nil, err - } - if existing == nil { - create := s.entClient.IdentityAdoptionDecision.Create(). - SetPendingAuthSessionID(input.PendingAuthSessionID). - SetAdoptDisplayName(input.AdoptDisplayName). - SetAdoptAvatar(input.AdoptAvatar). - SetDecidedAt(time.Now().UTC()) - if input.IdentityID != nil { - create = create.SetIdentityID(*input.IdentityID) - } - return create.Save(ctx) + create := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil && *input.IdentityID > 0 { + create = create.SetIdentityID(*input.IdentityID) } - update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID). - SetAdoptDisplayName(input.AdoptDisplayName). - SetAdoptAvatar(input.AdoptAvatar) - if input.IdentityID != nil { - update = update.SetIdentityID(*input.IdentityID) + decisionID, err := create. + OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID). + UpdateNewValues(). + ID(txCtx) + if err != nil { + return nil, err } - return update.Save(ctx) + + decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID) + if err != nil { + return nil, err + } + + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, err + } + } + + return decision, nil } func copyPendingMap(in map[string]any) map[string]any { diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go index deeeeb06..555bb0e7 100644 --- a/backend/internal/service/auth_pending_identity_service_test.go +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -5,6 +5,7 @@ package service import ( "context" "database/sql" + "sync" "testing" "time" @@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden require.Nil(t, reloadedFirst.IdentityID) } +func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption-concurrent@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-concurrent"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-concurrent", + }, + }) + require.NoError(t, err) + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type adoptionResult struct { + decision *dbent.IdentityAdoptionDecision + err error + } + + input := PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + } + + results := make(chan adoptionResult, 2) + go func() { + decision, err := svc.UpsertAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + <-firstCreateStarted + + go func() { + decision, err := svc.UpsertAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + second := <-results + + require.NoError(t, first.err) + require.NoError(t, second.err) + require.NotNil(t, first.decision) + require.NotNil(t, second.decision) + require.Equal(t, first.decision.ID, second.decision.ID) + + count, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, count) + + loaded, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, loaded.IdentityID) + require.Equal(t, identity.ID, *loaded.IdentityID) +} + func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) { t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL") diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index efe08644..59442d1f 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "crypto/sha256" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, Status: StatusActive, + SignupSource: signupSource, } if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, Status: StatusActive, + SignupSource: signupSource, } if s.entClient != nil && invitationRedeemCode != nil { @@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) { UserID: user.ID, Email: user.Email, Role: user.Role, - TokenVersion: user.TokenVersion, + TokenVersion: resolvedTokenVersion(user), RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expiresAt), IssuedAt: jwt.NewNumericDate(now), @@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( // Security: Check TokenVersion to prevent refreshing revoked tokens // This ensures tokens issued before a password change cannot be refreshed - if claims.TokenVersion != user.TokenVersion { + if claims.TokenVersion != resolvedTokenVersion(user) { return "", ErrTokenRevoked } @@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami data := &RefreshTokenData{ UserID: user.ID, - TokenVersion: user.TokenVersion, + TokenVersion: resolvedTokenVersion(user), FamilyID: familyID, CreatedAt: now, ExpiresAt: now.Add(ttl), @@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) } // 检查TokenVersion(密码更改后所有Token失效) - if data.TokenVersion != user.TokenVersion { + if data.TokenVersion != resolvedTokenVersion(user) { // TokenVersion不匹配,撤销整个Token家族 _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrTokenRevoked @@ -1492,3 +1495,14 @@ func hashToken(token string) string { hash := sha256.Sum256([]byte(token)) return hex.EncodeToString(hash[:]) } + +func resolvedTokenVersion(user *User) int64 { + if user == nil { + return 0 + } + + material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash + sum := sha256.Sum256([]byte(material)) + fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff) + return user.TokenVersion ^ fingerprint +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f08274c7..aac60b08 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -814,6 +814,20 @@ func parseCustomMenuItemURLs(raw string) []string { return urls } +func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool { + if base.UsePKCEExplicit { + return base.UsePKCE + } + return false +} + +func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool { + if base.ValidateIDTokenExplicit { + return base.ValidateIDToken + } + return false +} + // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { updates, err := s.buildSystemSettingsUpdates(ctx, settings) @@ -1479,6 +1493,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { return fmt.Errorf("check existing settings: %w", err) } + oidcUsePKCEDefault := true + oidcValidateIDTokenDefault := true + if s != nil && s.cfg != nil { + if s.cfg.OIDC.UsePKCEExplicit { + oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE + } + if s.cfg.OIDC.ValidateIDTokenExplicit { + oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken + } + } + // 初始化默认设置 defaults := map[string]string{ SettingKeyRegistrationEnabled: "true", @@ -1523,8 +1548,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOIDCConnectRedirectURL: "", SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", - SettingKeyOIDCConnectUsePKCE: "true", - SettingKeyOIDCConnectValidateIDToken: "true", + SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault), + SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault), SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", SettingKeyOIDCConnectClockSkewSeconds: "120", SettingKeyOIDCConnectRequireEmailVerified: "false", @@ -1767,12 +1792,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { result.OIDCConnectUsePKCE = raw == "true" } else { - result.OIDCConnectUsePKCE = oidcBase.UsePKCE + result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase) } if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { result.OIDCConnectValidateIDToken = raw == "true" } else { - result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken + result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase) } if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) @@ -2482,9 +2507,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. } if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { effective.UsePKCE = raw == "true" + } else { + effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective) } if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { effective.ValidateIDToken = raw == "true" + } else { + effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective) } if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { effective.AllowedSigningAlgs = strings.TrimSpace(v) diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go index eb312d2c..1ece6405 100644 --- a/backend/internal/service/setting_service_oidc_config_test.go +++ b/backend/internal/service/setting_service_oidc_config_test.go @@ -118,8 +118,10 @@ func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) { svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ OIDC: config.OIDCConnectConfig{ - UsePKCE: true, - ValidateIDToken: true, + UsePKCE: true, + UsePKCEExplicit: true, + ValidateIDToken: true, + ValidateIDTokenExplicit: true, }, }) @@ -131,6 +133,22 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue require.True(t, got.OIDCConnectValidateIDToken) } +func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { + svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ + OIDC: config.OIDCConnectConfig{ + UsePKCE: true, + ValidateIDToken: true, + }, + }) + + got := svc.parseSettings(map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }) + + require.False(t, got.OIDCConnectUsePKCE) + require.False(t, got.OIDCConnectValidateIDToken) +} + func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { cfg := &config.Config{ OIDC: config.OIDCConnectConfig{ @@ -163,6 +181,42 @@ func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTok } func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) { + cfg := &config.Config{ + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + JWKSURL: "https://issuer.example.com/jwks", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + UsePKCEExplicit: true, + ValidateIDToken: true, + ValidateIDTokenExplicit: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + } + + repo := &settingOIDCRepoStub{values: map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }} + svc := NewSettingService(repo, cfg) + + got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.UsePKCE) + require.True(t, got.ValidateIDToken) +} + +func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { cfg := &config.Config{ OIDC: config.OIDCConnectConfig{ Enabled: true, @@ -192,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) require.NoError(t, err) - require.True(t, got.UsePKCE) - require.True(t, got.ValidateIDToken) + require.False(t, got.UsePKCE) + require.False(t, got.ValidateIDToken) } diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql index fbaed62e..f59b2188 100644 --- a/backend/migrations/110_pending_auth_and_provider_default_grants.sql +++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql @@ -38,23 +38,22 @@ VALUES ('auth_source_default_email_balance', '0'), ('auth_source_default_email_concurrency', '5'), ('auth_source_default_email_subscriptions', '[]'), - ('auth_source_default_email_grant_on_signup', 'true'), + ('auth_source_default_email_grant_on_signup', 'false'), ('auth_source_default_email_grant_on_first_bind', 'false'), ('auth_source_default_linuxdo_balance', '0'), ('auth_source_default_linuxdo_concurrency', '5'), ('auth_source_default_linuxdo_subscriptions', '[]'), - ('auth_source_default_linuxdo_grant_on_signup', 'true'), + ('auth_source_default_linuxdo_grant_on_signup', 'false'), ('auth_source_default_linuxdo_grant_on_first_bind', 'false'), ('auth_source_default_oidc_balance', '0'), ('auth_source_default_oidc_concurrency', '5'), ('auth_source_default_oidc_subscriptions', '[]'), - ('auth_source_default_oidc_grant_on_signup', 'true'), + ('auth_source_default_oidc_grant_on_signup', 'false'), ('auth_source_default_oidc_grant_on_first_bind', 'false'), ('auth_source_default_wechat_balance', '0'), ('auth_source_default_wechat_concurrency', '5'), ('auth_source_default_wechat_subscriptions', '[]'), - ('auth_source_default_wechat_grant_on_signup', 'true'), + ('auth_source_default_wechat_grant_on_signup', 'false'), ('auth_source_default_wechat_grant_on_first_bind', 'false'), ('force_email_on_third_party_signup', 'false') ON CONFLICT (key) DO NOTHING; - diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql index 7a20f8eb..264da3c9 100644 --- a/backend/migrations/115_auth_identity_legacy_external_backfill.sql +++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql @@ -31,6 +31,41 @@ BEGIN END IF; EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.id, + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_username) AS provider_username, + BTRIM(uei.display_name) AS display_name, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + uei.created_at, + uei.updated_at + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_user_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_user_id +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_user_id + ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN legacy_subjects AS subjects + ON subjects.provider_subject = legacy.provider_user_id + AND subjects.distinct_user_count = 1 +) INSERT INTO auth_identities ( user_id, provider_type, @@ -52,11 +87,18 @@ SELECT 'display_name', legacy.display_name, 'migration', '115_auth_identity_legacy_external_backfill' ) -FROM ( +FROM canonical_legacy AS legacy +WHERE legacy.canonical_row_num = 1 +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; +$sql$; + + EXECUTE $sql$ +WITH legacy AS ( SELECT uei.id, uei.user_id, BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_union_id) AS provider_union_id, BTRIM(uei.provider_username) AS provider_username, BTRIM(uei.display_name) AS display_name, public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, @@ -65,13 +107,28 @@ FROM ( FROM user_external_identities AS uei JOIN users AS u ON u.id = uei.user_id WHERE u.deleted_at IS NULL - AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' - AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' -) AS legacy -ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; -$sql$; - - EXECUTE $sql$ + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_union_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_union_id +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_union_id + ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN legacy_subjects AS subjects + ON subjects.provider_subject = legacy.provider_union_id + AND subjects.distinct_user_count = 1 +) INSERT INTO auth_identities ( user_id, provider_type, @@ -96,27 +153,36 @@ SELECT 'display_name', legacy.display_name, 'migration', '115_auth_identity_legacy_external_backfill' ) -FROM ( - SELECT - uei.id, - uei.user_id, - BTRIM(uei.provider_user_id) AS provider_user_id, - BTRIM(uei.provider_union_id) AS provider_union_id, - BTRIM(uei.provider_username) AS provider_username, - BTRIM(uei.display_name) AS display_name, - public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, - uei.created_at, - uei.updated_at - FROM user_external_identities AS uei - JOIN users AS u ON u.id = uei.user_id - WHERE u.deleted_at IS NULL - AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' - AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' -) AS legacy +FROM canonical_legacy AS legacy +WHERE legacy.canonical_row_num = 1 ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; $sql$; EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_union_id) AS provider_union_id, + BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel, + BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id, + meta.metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + CROSS JOIN LATERAL ( + SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + ) AS meta + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_union_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_union_id +) INSERT INTO auth_identity_channels ( identity_id, provider_type, @@ -138,23 +204,10 @@ SELECT 'unionid', legacy.provider_union_id, 'migration', '115_auth_identity_legacy_external_backfill' ) -FROM ( - SELECT - uei.user_id, - BTRIM(uei.provider_user_id) AS provider_user_id, - BTRIM(uei.provider_union_id) AS provider_union_id, - BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel, - BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id, - meta.metadata_json - FROM user_external_identities AS uei - JOIN users AS u ON u.id = uei.user_id - CROSS JOIN LATERAL ( - SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json - ) AS meta - WHERE u.deleted_at IS NULL - AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' - AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' -) AS legacy +FROM legacy +JOIN legacy_subjects AS subjects + ON subjects.provider_subject = legacy.provider_union_id + AND subjects.distinct_user_count = 1 JOIN auth_identities AS ai ON ai.user_id = legacy.user_id AND ai.provider_type = 'wechat' diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql index 3983bb1a..81eb133c 100644 --- a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql +++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql @@ -74,6 +74,82 @@ $sql$; EXECUTE $sql$ INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_external_identity_conflict', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'legacy_user_id', legacy.user_id, + 'provider_type', legacy.provider_type, + 'provider_key', legacy.provider_key, + 'provider_subject', legacy.provider_subject, + 'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids, + 'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) +) AS legacy +JOIN ( + SELECT + provider_type, + provider_key, + provider_subject, + to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids + FROM ( + SELECT + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) + ) AS legacy_subjects + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) > 1 +) AS ambiguous + ON ambiguous.provider_type = legacy.provider_type + AND ambiguous.provider_key = legacy.provider_key + AND ambiguous.provider_subject = legacy.provider_subject +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) SELECT 'legacy_external_identity_conflict', 'legacy_external_identity:' || legacy.id::text, @@ -116,6 +192,39 @@ FROM ( (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') ) ) AS legacy +JOIN ( + SELECT + provider_type, + provider_key, + provider_subject + FROM ( + SELECT + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) + ) AS legacy_subjects + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) = 1 +) AS clear_subjects + ON clear_subjects.provider_type = legacy.provider_type + AND clear_subjects.provider_key = legacy.provider_key + AND clear_subjects.provider_subject = legacy.provider_subject JOIN auth_identities AS ai ON ai.provider_type = legacy.provider_type AND ai.provider_key = legacy.provider_key @@ -125,29 +234,7 @@ ON CONFLICT (report_type, report_key) DO NOTHING; $sql$; EXECUTE $sql$ -INSERT INTO auth_identities ( - user_id, - provider_type, - provider_key, - provider_subject, - verified_at, - metadata -) -SELECT - legacy.user_id, - legacy.provider_type, - legacy.provider_key, - legacy.provider_subject, - legacy.verified_at, - legacy.metadata_json || jsonb_build_object( - 'legacy_identity_id', legacy.id, - 'provider_user_id', legacy.provider_user_id, - 'provider_union_id', NULLIF(legacy.provider_union_id, ''), - 'provider_username', legacy.provider_username, - 'display_name', legacy.display_name, - 'migration', '116_auth_identity_legacy_external_safety_reports' - ) -FROM ( +WITH legacy AS ( SELECT uei.id, uei.user_id, @@ -175,12 +262,58 @@ FROM ( OR (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') ) -) AS legacy +), +clear_subjects AS ( + SELECT + provider_type, + provider_key, + provider_subject + FROM legacy + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) = 1 +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject + ORDER BY legacy.verified_at DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN clear_subjects + ON clear_subjects.provider_type = legacy.provider_type + AND clear_subjects.provider_key = legacy.provider_key + AND clear_subjects.provider_subject = legacy.provider_subject +) +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + legacy.user_id, + legacy.provider_type, + legacy.provider_key, + legacy.provider_subject, + legacy.verified_at, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'provider_user_id', legacy.provider_user_id, + 'provider_union_id', NULLIF(legacy.provider_union_id, ''), + 'provider_username', legacy.provider_username, + 'display_name', legacy.display_name, + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM canonical_legacy AS legacy LEFT JOIN auth_identities AS ai ON ai.provider_type = legacy.provider_type AND ai.provider_key = legacy.provider_key AND ai.provider_subject = legacy.provider_subject -WHERE ai.id IS NULL +WHERE legacy.canonical_row_num = 1 + AND ai.id IS NULL ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; $sql$; @@ -225,6 +358,19 @@ FROM ( AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' ) AS legacy +JOIN ( + SELECT + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' + GROUP BY BTRIM(COALESCE(uei.provider_union_id, '')) + HAVING COUNT(DISTINCT uei.user_id) = 1 +) AS clear_subjects + ON clear_subjects.provider_subject = legacy.provider_union_id JOIN auth_identities AS legacy_ai ON legacy_ai.user_id = legacy.user_id AND legacy_ai.provider_type = 'wechat' @@ -245,6 +391,33 @@ ON CONFLICT (report_type, report_key) DO NOTHING; $sql$; EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.user_id, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel, + BTRIM(COALESCE( + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id', + '' + )) AS channel_app_id + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +), +clear_subjects AS ( + SELECT + provider_union_id AS provider_subject + FROM legacy + GROUP BY provider_union_id + HAVING COUNT(DISTINCT user_id) = 1 +) INSERT INTO auth_identity_channels ( identity_id, provider_type, @@ -266,26 +439,9 @@ SELECT 'unionid', legacy.provider_union_id, 'migration', '116_auth_identity_legacy_external_safety_reports' ) -FROM ( - SELECT - uei.user_id, - BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, - BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, - public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, - BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel, - BTRIM(COALESCE( - public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id', - public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid', - public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id', - '' - )) AS channel_app_id - FROM user_external_identities AS uei - JOIN users AS u ON u.id = uei.user_id - WHERE u.deleted_at IS NULL - AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' - AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' - AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' -) AS legacy +FROM legacy +JOIN clear_subjects + ON clear_subjects.provider_subject = legacy.provider_union_id JOIN auth_identities AS legacy_ai ON legacy_ai.user_id = legacy.user_id AND legacy_ai.provider_type = 'wechat' diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql index 094b223c..4388285a 100644 --- a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql +++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql @@ -1,3 +1,68 @@ --- Intentionally left as a no-op. --- Legacy installs may have intentionally kept the original signup grant defaults, --- and we cannot distinguish those cases safely from untouched migration 110 rows. +-- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value. +-- Rows still matching the migration-110 default payload and timestamp window are treated as +-- untouched legacy defaults; any remaining legacy true values are reported for manual review. + +WITH migration_110 AS ( + SELECT applied_at + FROM schema_migrations + WHERE filename = '110_pending_auth_and_provider_default_grants.sql' +), +providers AS ( + SELECT provider_type + FROM ( + VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat') + ) AS providers(provider_type) +), +legacy_provider_defaults AS ( + SELECT providers.provider_type + FROM providers + CROSS JOIN migration_110 + JOIN settings balance + ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance' + JOIN settings concurrency + ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency' + JOIN settings subscriptions + ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions' + JOIN settings grant_on_signup + ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' + JOIN settings grant_on_first_bind + ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind' + WHERE balance.value = '0' + AND concurrency.value = '5' + AND subscriptions.value = '[]' + AND grant_on_signup.value = 'true' + AND grant_on_first_bind.value = 'false' + AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' +), +updated_signup_grants AS ( + UPDATE settings + SET + value = 'false', + updated_at = NOW() + FROM legacy_provider_defaults + WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup' + AND settings.value = 'true' + RETURNING legacy_provider_defaults.provider_type +) +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_auth_source_signup_grant_review', + providers.provider_type, + jsonb_build_object( + 'provider_type', providers.provider_type, + 'current_value', grant_on_signup.value, + 'auto_backfilled', FALSE, + 'reason', 'legacy_true_default_not_auto_backfilled' + ) +FROM providers +JOIN settings grant_on_signup + ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' +LEFT JOIN updated_signup_grants + ON updated_signup_grants.provider_type = providers.provider_type +WHERE grant_on_signup.value = 'true' + AND updated_signup_grants.provider_type IS NULL +ON CONFLICT (report_type, report_key) DO NOTHING; From 1aab084ecb9ad1e92dce3f273d8dc73e11cb9f5c Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 14:57:16 +0800 Subject: [PATCH 21/31] fix(payment): restore upgrade-safe payment flows --- backend/internal/handler/auth_wechat_oauth.go | 7 +- .../handler/auth_wechat_oauth_test.go | 62 ++++ backend/internal/payment/provider/wxpay.go | 25 +- .../internal/payment/provider/wxpay_test.go | 66 +++++ .../service/payment_config_providers.go | 80 ++++-- .../service/payment_config_providers_test.go | 266 ++++++++++++++++++ .../service/payment_resume_service_test.go | 39 +++ backend/internal/service/payment_service.go | 29 +- ...ayment_orders_out_trade_no_unique_notx.sql | 2 + ...tity_payment_migrations_regression_test.go | 23 +- frontend/src/views/user/PaymentResultView.vue | 45 +-- frontend/src/views/user/PaymentView.vue | 27 +- .../user/__tests__/PaymentResultView.spec.ts | 15 +- .../views/user/__tests__/PaymentView.spec.ts | 27 ++ 14 files changed, 645 insertions(+), 68 deletions(-) diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index dc93fcae..efee4cc0 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -471,11 +471,12 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) { } func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService { + var legacyKey []byte key, err := payment.ProvideEncryptionKey(h.cfg) - if err != nil { - return service.NewPaymentResumeService(nil) + if err == nil { + legacyKey = []byte(key) } - return service.NewPaymentResumeService([]byte(key)) + return service.NewLegacyAwarePaymentResumeService(legacyKey) } type completeWeChatOAuthRequest struct { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index d303bd42..7cf114c1 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -378,6 +378,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")) defer client.Close() handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + handler.cfg.Totp.EncryptionKeyConfigured = true recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -415,6 +416,67 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) require.Equal(t, "/purchase?from=wechat", claims.RedirectTo) } +func TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`)) + return + } + http.NotFound(w, r) + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")) + defer client.Close() + + legacyKeyHex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + explicitSigningKey := "explicit-payment-resume-signing-key" + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", explicitSigningKey) + handler.cfg.Totp.EncryptionKey = legacyKeyHex + handler.cfg.Totp.EncryptionKeyConfigured = true + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-mixed")) + req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat")) + req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`)) + req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base")) + c.Request = req + + handler.WeChatPaymentOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + parsed, err := url.Parse(location) + require.NoError(t, err) + fragment, err := url.ParseQuery(parsed.Fragment) + require.NoError(t, err) + + token := fragment.Get("wechat_resume_token") + require.NotEmpty(t, token) + + claims, err := service.NewPaymentResumeService([]byte(explicitSigningKey)).ParseWeChatPaymentResumeToken(token) + require.NoError(t, err) + require.Equal(t, "openid-mixed-key", claims.OpenID) + require.Equal(t, payment.TypeWxpay, claims.PaymentType) + require.Equal(t, "18.8", claims.Amount) + require.Equal(t, payment.OrderTypeSubscription, claims.OrderType) + require.EqualValues(t, 9, claims.PlanID) + require.Equal(t, "/purchase?from=wechat", claims.RedirectTo) + + _, err = service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")).ParseWeChatPaymentResumeToken(token) + require.Error(t, err) +} + func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) { testCases := []struct { name string diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go index 4b334513..9927a265 100644 --- a/backend/internal/payment/provider/wxpay.go +++ b/backend/internal/payment/provider/wxpay.go @@ -204,8 +204,8 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ if err == nil { return resp, nil } - if strings.Contains(err.Error(), wxpayErrNoAuth) { - return nil, fmt.Errorf("wxpay h5 payments are not authorized for this merchant: %w", err) + if wxpayShouldFallbackToNative(err) { + return w.prepayNativeFallback(ctx, client, req, notifyURL, totalFen) } return nil, err case wxpayModeNative: @@ -292,6 +292,23 @@ func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.Create return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil } +func (w *Wxpay) prepayNativeFallback(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) { + resp, err := w.prepayNative(ctx, c, req, notifyURL, totalFen) + if err != nil { + return nil, fmt.Errorf("wxpay native fallback after NO_AUTH: %w", err) + } + nativeURL := strings.TrimSpace(resp.PayURL) + if nativeURL == "" { + nativeURL = strings.TrimSpace(resp.QRCode) + } + if nativeURL == "" { + return resp, nil + } + resp.PayURL = nativeURL + resp.QRCode = nativeURL + return resp, nil +} + func buildWxpayH5Info(config map[string]string) *h5.H5Info { tp := wxpayH5Type info := &h5.H5Info{Type: &tp} @@ -304,6 +321,10 @@ func buildWxpayH5Info(config map[string]string) *h5.H5Info { return info } +func wxpayShouldFallbackToNative(err error) bool { + return err != nil && strings.Contains(err.Error(), wxpayErrNoAuth) +} + func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) { if strings.TrimSpace(req.OpenID) != "" { return wxpayModeJSAPI, nil diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go index ebbd9d34..a5a406f9 100644 --- a/backend/internal/payment/provider/wxpay_test.go +++ b/backend/internal/payment/provider/wxpay_test.go @@ -8,6 +8,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" "net/url" "strings" "testing" @@ -641,3 +642,68 @@ func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) { t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL) } } + +func TestCreatePaymentMobileH5FallsBackToNativeOnNoAuth(t *testing.T) { + origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment + origNativePrepay := wxpayNativePrepay + origH5Prepay := wxpayH5Prepay + t.Cleanup(func() { + wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay + wxpayNativePrepay = origNativePrepay + wxpayH5Prepay = origH5Prepay + }) + + jsapiCalls := 0 + nativeCalls := 0 + h5Calls := 0 + wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { + jsapiCalls++ + return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil + } + wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { + h5Calls++ + return nil, nil, errors.New("NO_AUTH") + } + wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { + nativeCalls++ + return &native.PrepayResponse{ + CodeUrl: core.String("weixin://wxpay/bizpayurl?pr=fallback-native"), + }, nil, nil + } + + provider := &Wxpay{ + config: map[string]string{ + "appId": "wx123", + "mchId": "mch123", + }, + coreClient: &core.Client{}, + } + + resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{ + OrderID: "sub2_100", + Amount: "66.88", + PaymentType: payment.TypeWxpay, + Subject: "Balance Recharge", + NotifyURL: "https://merchant.example/payment/notify", + ClientIP: "203.0.113.10", + IsMobile: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if jsapiCalls != 0 { + t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls) + } + if h5Calls != 1 { + t.Fatalf("h5 prepay calls = %d, want 1", h5Calls) + } + if nativeCalls != 1 { + t.Fatalf("native prepay calls = %d, want 1", nativeCalls) + } + if resp.PayURL != "weixin://wxpay/bizpayurl?pr=fallback-native" { + t.Fatalf("pay_url = %q, want native fallback url", resp.PayURL) + } + if resp.QRCode != "weixin://wxpay/bizpayurl?pr=fallback-native" { + t.Fatalf("qr_code = %q, want native fallback url", resp.QRCode) + } +} diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index d2f89b06..ff05e559 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -116,6 +116,17 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{ payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}}, } +// providerPendingOrderProtectedConfigFields lists config keys that cannot be +// changed while the instance has in-progress orders. This includes secrets plus +// all provider identity fields that are snapshotted into orders or used by +// webhook/refund verification. +var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{ + payment.TypeEasyPay: {"pkey": {}, "pid": {}}, + payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}}, + payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}}, + payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}}, +} + func isSensitiveProviderConfigField(providerKey, fieldName string) bool { fields, ok := providerSensitiveConfigFields[providerKey] if !ok { @@ -125,6 +136,28 @@ func isSensitiveProviderConfigField(providerKey, fieldName string) bool { return found } +func hasPendingOrderProtectedConfigChange(providerKey string, currentConfig, nextConfig map[string]string) bool { + fields, ok := providerPendingOrderProtectedConfigFields[providerKey] + if !ok { + return false + } + for fieldName := range fields { + if providerConfigFieldValue(currentConfig, fieldName) != providerConfigFieldValue(nextConfig, fieldName) { + return true + } + } + return false +} + +func providerConfigFieldValue(config map[string]string, fieldName string) string { + for key, value := range config { + if strings.EqualFold(key, fieldName) { + return value + } + } + return "" +} + func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) { return s.entClient.PaymentOrder.Query(). Where( @@ -190,6 +223,18 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in if err != nil { return nil, fmt.Errorf("load provider instance: %w", err) } + var pendingOrderCount *int + getPendingOrderCount := func() (int, error) { + if pendingOrderCount != nil { + return *pendingOrderCount, nil + } + count, err := s.countPendingOrders(ctx, id) + if err != nil { + return 0, fmt.Errorf("check pending orders: %w", err) + } + pendingOrderCount = &count + return count, nil + } nextEnabled := current.Enabled if req.Enabled != nil { nextEnabled = *req.Enabled @@ -201,18 +246,20 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil { return nil, err } + var mergedConfig map[string]string if req.Config != nil { - hasSensitive := false - for k, v := range req.Config { - if v != "" && isSensitiveProviderConfigField(current.ProviderKey, k) { - hasSensitive = true - break - } + currentConfig, err := s.decryptConfig(current.Config) + if err != nil { + return nil, fmt.Errorf("decrypt existing config: %w", err) } - if hasSensitive { - count, err := s.countPendingOrders(ctx, id) + mergedConfig, err = s.mergeConfig(ctx, id, req.Config) + if err != nil { + return nil, err + } + if hasPendingOrderProtectedConfigChange(current.ProviderKey, currentConfig, mergedConfig) { + count, err := getPendingOrderCount() if err != nil { - return nil, fmt.Errorf("check pending orders: %w", err) + return nil, err } if count > 0 { return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders"). @@ -221,9 +268,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } } if req.Enabled != nil && !*req.Enabled { - count, err := s.countPendingOrders(ctx, id) + count, err := getPendingOrderCount() if err != nil { - return nil, fmt.Errorf("check pending orders: %w", err) + return nil, err } if count > 0 { return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders"). @@ -237,13 +284,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in if req.Enabled != nil { finalEnabled = *req.Enabled } - var mergedConfig map[string]string - if req.Config != nil { - mergedConfig, err = s.mergeConfig(ctx, id, req.Config) - if err != nil { - return nil, err - } - } if finalEnabled { configToValidate := mergedConfig if configToValidate == nil { @@ -269,9 +309,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.SupportedTypes != nil { // Check pending orders before removing payment types - count, err := s.countPendingOrders(ctx, id) + count, err := getPendingOrderCount() if err != nil { - return nil, fmt.Errorf("check pending orders: %w", err) + return nil, err } if count > 0 { // Load current instance to compare types diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index 51d5c7b6..e0d2908a 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -8,8 +8,13 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "strconv" "testing" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -315,10 +320,263 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { require.Equal(t, "alipay,wxpay", saved.SupportedTypes) } +func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providerKey string + createConfig func(*testing.T) map[string]string + supportedType []string + updateConfig map[string]string + fieldName string + wantValue string + }{ + { + name: "wxpay appId", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"appId": "wx-app-updated"}, + fieldName: "appId", + wantValue: "wx-app-test", + }, + { + name: "wxpay mpAppId", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfigWithJSAPIAppID, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"mpAppId": "wx-mp-app-updated"}, + fieldName: "mpAppId", + wantValue: "wx-mp-app-test", + }, + { + name: "wxpay mchId", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"mchId": "mch-updated"}, + fieldName: "mchId", + wantValue: "mch-test", + }, + { + name: "wxpay publicKeyId", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"publicKeyId": "public-key-id-updated"}, + fieldName: "publicKeyId", + wantValue: "public-key-id-test", + }, + { + name: "wxpay certSerial", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"certSerial": "cert-serial-updated"}, + fieldName: "certSerial", + wantValue: "cert-serial-test", + }, + { + name: "alipay appId", + providerKey: payment.TypeAlipay, + createConfig: validAlipayProviderConfig, + supportedType: []string{payment.TypeAlipay}, + updateConfig: map[string]string{"appId": "alipay-app-updated"}, + fieldName: "appId", + wantValue: "alipay-app-test", + }, + { + name: "easypay pid", + providerKey: payment.TypeEasyPay, + createConfig: validEasyPayProviderConfig, + supportedType: []string{payment.TypeAlipay}, + updateConfig: map[string]string{"pid": "pid-updated"}, + fieldName: "pid", + wantValue: "pid-test", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: tc.providerKey, + Name: "protected-config-instance", + Config: tc.createConfig(t), + SupportedTypes: tc.supportedType, + Enabled: true, + }) + require.NoError(t, err) + + createPendingProviderConfigOrder(t, ctx, client, instance) + + updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{ + Config: tc.updateConfig, + }) + require.Nil(t, updated) + require.Error(t, err) + require.Equal(t, "PENDING_ORDERS", infraerrors.Reason(err)) + + saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID) + require.NoError(t, err) + cfg, err := svc.decryptConfig(saved.Config) + require.NoError(t, err) + require.Equal(t, tc.wantValue, cfg[tc.fieldName]) + }) + } +} + +func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providerKey string + createConfig func(*testing.T) map[string]string + supportedType []string + updateConfig map[string]string + fieldName string + wantValue string + }{ + { + name: "wxpay notifyUrl", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"notifyUrl": "https://merchant.example.com/wxpay/notify-v2"}, + fieldName: "notifyUrl", + wantValue: "https://merchant.example.com/wxpay/notify-v2", + }, + { + name: "alipay same appId", + providerKey: payment.TypeAlipay, + createConfig: validAlipayProviderConfig, + supportedType: []string{payment.TypeAlipay}, + updateConfig: map[string]string{"appId": "alipay-app-test"}, + fieldName: "appId", + wantValue: "alipay-app-test", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: tc.providerKey, + Name: "safe-config-instance", + Config: tc.createConfig(t), + SupportedTypes: tc.supportedType, + Enabled: true, + }) + require.NoError(t, err) + + createPendingProviderConfigOrder(t, ctx, client, instance) + + updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{ + Config: tc.updateConfig, + }) + require.NoError(t, err) + require.NotNil(t, updated) + + saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID) + require.NoError(t, err) + cfg, err := svc.decryptConfig(saved.Config) + require.NoError(t, err) + require.Equal(t, tc.wantValue, cfg[tc.fieldName]) + }) + } +} + +func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) { + t.Helper() + + user, err := client.User.Create(). + SetEmail("provider-config-pending@example.com"). + SetPasswordHash("hash"). + SetUsername("provider-config-pending-user"). + Save(ctx) + require.NoError(t, err) + + instanceID := strconv.FormatInt(instance.ID, 10) + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("PENDING-PROVIDER-CONFIG-" + instanceID). + SetOutTradeNo("sub2_pending_provider_config_" + instanceID). + SetPaymentType(providerPendingOrderPaymentType(instance.ProviderKey)). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(instanceID). + SetProviderKey(instance.ProviderKey). + Save(ctx) + require.NoError(t, err) +} + +func providerPendingOrderPaymentType(providerKey string) string { + switch providerKey { + case payment.TypeWxpay: + return payment.TypeWxpay + case payment.TypeAlipay: + return payment.TypeAlipay + default: + return payment.TypeAlipay + } +} + func boolPtrValue(v bool) *bool { return &v } +func validAlipayProviderConfig(t *testing.T) map[string]string { + t.Helper() + + return map[string]string{ + "appId": "alipay-app-test", + "privateKey": "alipay-private-key-test", + "notifyUrl": "https://merchant.example.com/alipay/notify", + "returnUrl": "https://merchant.example.com/alipay/return", + } +} + +func validEasyPayProviderConfig(t *testing.T) map[string]string { + t.Helper() + + return map[string]string{ + "pid": "pid-test", + "pkey": "pkey-test", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/easypay/notify", + "returnUrl": "https://merchant.example.com/easypay/return", + } +} + func validWxpayProviderConfig(t *testing.T) map[string]string { t.Helper() @@ -340,3 +598,11 @@ func validWxpayProviderConfig(t *testing.T) map[string]string { "certSerial": "cert-serial-test", } } + +func validWxpayProviderConfigWithJSAPIAppID(t *testing.T) map[string]string { + t.Helper() + + cfg := validWxpayProviderConfig(t) + cfg["mpAppId"] = "wx-mp-app-test" + return cfg +} diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index 59a2221e..7e0adc2d 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -387,6 +387,45 @@ func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDu } } +func TestNewConfiguredPaymentResumeServicePrefersExplicitSigningKeyAndKeepsLegacyVerificationFallback(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key") + + legacyKey := []byte("0123456789abcdef0123456789abcdef") + svc := newLegacyAwarePaymentResumeService(legacyKey) + + explicitToken, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-explicit-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + explicitClaims, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).ParseWeChatPaymentResumeToken(explicitToken) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if explicitClaims.OpenID != "openid-explicit-key" { + t.Fatalf("openid = %q, want %q", explicitClaims.OpenID, "openid-explicit-key") + } + + legacyToken, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-legacy-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + legacyClaims, err := svc.ParseWeChatPaymentResumeToken(legacyToken) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if legacyClaims.OpenID != "openid-legacy-key" { + t.Fatalf("openid = %q, want %q", legacyClaims.OpenID, "openid-legacy-key") + } +} + func TestNormalizeVisibleMethodSource(t *testing.T) { t.Parallel() diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 159f97d3..d39d2b19 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -268,8 +268,16 @@ func (s *PaymentService) paymentResume() *PaymentResumeService { return psNewPaymentResumeService(s.configService) } +func NewLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService { + return newLegacyAwarePaymentResumeService(legacyKey) +} + func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService { - signingKey, verifyFallbacks := psResumeSigningKeys(configService) + return newLegacyAwarePaymentResumeService(psResumeLegacyVerificationKey(configService)) +} + +func newLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService { + signingKey, verifyFallbacks := resolvePaymentResumeSigningKeys(legacyKey) return NewPaymentResumeService(signingKey, verifyFallbacks...) } @@ -279,8 +287,18 @@ func psResumeSigningKey(configService *PaymentConfigService) []byte { } func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) { + return resolvePaymentResumeSigningKeys(psResumeLegacyVerificationKey(configService)) +} + +func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte { + if configService == nil { + return nil + } + return configService.encryptionKey +} + +func resolvePaymentResumeSigningKeys(legacyKey []byte) ([]byte, [][]byte) { signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv)) - legacyKey := psResumeLegacyVerificationKey(configService) if len(signingKey) == 0 { if len(legacyKey) == 0 { return nil, nil @@ -293,13 +311,6 @@ func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) return signingKey, [][]byte{legacyKey} } -func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte { - if configService == nil { - return nil - } - return configService.encryptionKey -} - func parsePaymentResumeSigningKey(raw string) []byte { raw = strings.TrimSpace(raw) if raw == "" { diff --git a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql index 00836698..638d8622 100644 --- a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql +++ b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql @@ -1,4 +1,6 @@ -- Build the payment order uniqueness guarantee online. +-- The migration runner performs an explicit duplicate out_trade_no precheck and +-- drops any stale invalid paymentorder_out_trade_no_unique index before retrying. -- Create the new partial unique index concurrently first so writes keep flowing, -- then remove the legacy index name once the replacement is ready. CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index dcb0bb9c..6a95d335 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -63,6 +63,8 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) { require.NoError(t, err) followupSQL := string(followupContent) + require.Contains(t, followupSQL, "explicit duplicate out_trade_no precheck") + require.Contains(t, followupSQL, "stale invalid paymentorder_out_trade_no_unique index") require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique") require.NotContains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique") require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") @@ -76,6 +78,18 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) { require.Contains(t, alignmentSQL, "RENAME TO paymentorder_out_trade_no") } +func TestMigration110SeedsAuthSourceSignupGrantsDisabledByDefault(t *testing.T) { + content, err := FS.ReadFile("110_pending_auth_and_provider_default_grants.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "('auth_source_default_email_grant_on_signup', 'false')") + require.Contains(t, sql, "('auth_source_default_linuxdo_grant_on_signup', 'false')") + require.Contains(t, sql, "('auth_source_default_oidc_grant_on_signup', 'false')") + require.Contains(t, sql, "('auth_source_default_wechat_grant_on_signup', 'false')") + require.NotContains(t, sql, "('auth_source_default_email_grant_on_signup', 'true')") +} + func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) { content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql") require.NoError(t, err) @@ -94,7 +108,10 @@ func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) require.NoError(t, err) sql := string(content) - require.Contains(t, sql, "Intentionally left as a no-op") - require.NotContains(t, sql, "UPDATE settings") - require.NotContains(t, sql, "value = 'false'") + require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql") + require.Contains(t, sql, "schema_migrations") + require.Contains(t, sql, "updated_at") + require.Contains(t, sql, "'_grant_on_signup'") + require.Contains(t, sql, "value = 'false'") + require.Contains(t, sql, "auth_identity_migration_reports") } diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index b75d75df..f8ed37d9 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -291,6 +291,7 @@ onMounted(async () => { const routeOrderId = Number(readRouteQueryString('order_id')) || 0 let outTradeNo = readRouteQueryString('out_trade_no') let orderId = 0 + let resumeTokenLookupFailed = false const restored = restoreRecoverySnapshot({ resumeToken, @@ -312,24 +313,17 @@ onMounted(async () => { orderId = resolvedOrder.id } } else if (routeOrderId > 0) { + resumeTokenLookupFailed = true orderId = routeOrderId + } else { + resumeTokenLookupFailed = true } } else if (routeOrderId > 0) { orderId = routeOrderId } const hasLegacyFallbackContext = readRouteQueryString('trade_status').trim() !== '' - const shouldUsePublicOutTradeNo = !resumeToken && outTradeNo !== '' && (hasLegacyFallbackContext || routeOrderId > 0 || orderId > 0) - - if (!order.value && shouldUsePublicOutTradeNo) { - const legacyOrder = await resolveOrderFromOutTradeNo(outTradeNo) - if (legacyOrder) { - order.value = legacyOrder - if (!orderId) { - orderId = legacyOrder.id - } - } - } + const shouldUsePublicOutTradeNo = outTradeNo !== '' && (hasLegacyFallbackContext || routeOrderId > 0 || orderId > 0) if (!order.value && orderId && (!resumeToken || routeOrderId > 0)) { try { @@ -339,7 +333,17 @@ onMounted(async () => { } } - if (!order.value && !resumeToken && !orderId && outTradeNo && hasLegacyFallbackContext) { + if (!order.value && shouldUsePublicOutTradeNo && (!resumeToken || resumeTokenLookupFailed)) { + const legacyOrder = await resolveOrderFromOutTradeNo(outTradeNo) + if (legacyOrder) { + order.value = legacyOrder + if (!orderId) { + orderId = legacyOrder.id + } + } + } + + if (!order.value && !orderId && outTradeNo && hasLegacyFallbackContext) { returnInfo.value = { outTradeNo, money: String(route.query.money || ''), @@ -350,17 +354,24 @@ onMounted(async () => { const refreshOrder = async (): Promise => { if (resumeToken) { - return await resolveOrderFromResumeToken(resumeToken) + const resolvedOrder = await resolveOrderFromResumeToken(resumeToken) + if (resolvedOrder) { + return resolvedOrder + } + } + + if (orderId) { + try { + return await paymentStore.pollOrderStatus(orderId) + } catch (_err: unknown) { + // Fall through to legacy public verification when order polling is unavailable. + } } if (shouldUsePublicOutTradeNo) { return await resolveOrderFromOutTradeNo(outTradeNo) } - if (orderId) { - return await paymentStore.pollOrderStatus(orderId) - } - return null } diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 05d70512..1040d3f6 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -740,18 +740,23 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n return } if (decision.kind === 'wechat_jsapi' && decision.jsapi) { - const jsapiResult = await invokeWechatJsapiPayment(decision.jsapi as Record) - const errMsg = String(jsapiResult.err_msg || '').toLowerCase() - if (errMsg.includes('cancel')) { - appStore.showInfo(t('payment.qr.cancelled')) + try { + const jsapiResult = await invokeWechatJsapiPayment(decision.jsapi as Record) + const errMsg = String(jsapiResult.err_msg || '').toLowerCase() + if (errMsg.includes('cancel')) { + appStore.showInfo(t('payment.qr.cancelled')) + resetPayment() + } else if (errMsg && !errMsg.includes('ok')) { + applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod) + resetPayment() + } else { + const resultState = { ...decision.paymentState } + resetPayment() + await redirectToPaymentResult(resultState) + } + } catch (err: unknown) { resetPayment() - } else if (errMsg && !errMsg.includes('ok')) { - applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod) - resetPayment() - } else { - const resultState = { ...decision.paymentState } - resetPayment() - await redirectToPaymentResult(resultState) + throw err } return } diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts index 81a7ccf0..556cf390 100644 --- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -255,14 +255,21 @@ describe('PaymentResultView', () => { expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() }) - it('does not fall back to public out_trade_no verification when resume_token recovery fails', async () => { + it('falls back to public out_trade_no verification when resume_token recovery fails in legacy return flows', async () => { routeState.query = { resume_token: 'resume-fail', out_trade_no: 'legacy-should-not-run', trade_status: 'TRADE_SUCCESS', } resolveOrderPublicByResumeToken.mockRejectedValueOnce(new Error('resume failed')) - mount(PaymentResultView, { + verifyOrderPublic.mockResolvedValueOnce({ + data: { + ...orderFactory('PAID'), + out_trade_no: 'legacy-should-not-run', + }, + }) + + const wrapper = mount(PaymentResultView, { global: { stubs: { OrderStatusBadge: true, @@ -273,7 +280,9 @@ describe('PaymentResultView', () => { await flushPromises() expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-fail') - expect(verifyOrderPublic).not.toHaveBeenCalled() + expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-should-not-run') + expect(pollOrderStatus).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('payment.result.success') }) it('ignores a stale global recovery snapshot when legacy return markers do not identify the order', async () => { diff --git a/frontend/src/views/user/__tests__/PaymentView.spec.ts b/frontend/src/views/user/__tests__/PaymentView.spec.ts index 2b81a085..d2683161 100644 --- a/frontend/src/views/user/__tests__/PaymentView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentView.spec.ts @@ -252,6 +252,33 @@ describe('PaymentView WeChat JSAPI flow', () => { expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() }) + it('clears stale recovery state when JSAPI never becomes available', async () => { + vi.useFakeTimers() + createOrder.mockResolvedValue(jsapiOrderFixture('resume-token-missing-bridge')) + ;(window as Window & { WeixinJSBridge?: { invoke: typeof bridgeInvoke } }).WeixinJSBridge = undefined + + const wrapper = shallowMount(PaymentView, { + global: { + stubs: { + Teleport: true, + Transition: false, + }, + }, + }) + + await flushPromises() + await vi.advanceTimersByTimeAsync(4000) + await flushPromises() + await flushPromises() + + expect(showError).toHaveBeenCalledWith( + 'payment.errors.wechatJsapiUnavailable payment.errors.wechatOpenInWeChatHint', + ) + expect(routerPush).not.toHaveBeenCalled() + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toBeNull() + expect(wrapper.html()).not.toContain('payment-status-panel-stub') + }) + it('clears a stale recovery snapshot before handling wechat resume callback params', async () => { createOrder.mockRejectedValueOnce(new Error('resume failed')) window.localStorage.setItem(PAYMENT_RECOVERY_STORAGE_KEY, JSON.stringify({ From ca4e38aa0167d2b7a013b6d363c950a3a15dc834 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 14:57:47 +0800 Subject: [PATCH 22/31] fix(profile): stabilize binding compatibility and frontend checks --- .github/workflows/backend-ci.yml | 22 +- Makefile | 14 +- README.md | 12 +- README_CN.md | 12 +- README_JA.md | 12 +- backend/internal/handler/user_handler.go | 6 +- backend/internal/handler/user_handler_test.go | 46 ++- backend/internal/server/api_contract_test.go | 18 +- backend/internal/server/routes/auth.go | 36 +- .../internal/service/auth_email_binding.go | 9 + .../service/auth_service_email_bind_test.go | 326 +++++++++++++++++- backend/internal/service/user_service.go | 78 ++++- backend/internal/service/user_service_test.go | 100 ++++++ .../src/api/__tests__/admin.users.spec.ts | 117 +++++++ frontend/src/api/__tests__/client.spec.ts | 16 + frontend/src/api/__tests__/user.spec.ts | 32 ++ frontend/src/api/admin/users.ts | 26 +- frontend/src/api/auth.ts | 5 +- frontend/src/api/client.ts | 1 + frontend/src/api/user.ts | 2 +- .../ProfileIdentityBindingsSection.vue | 64 +++- .../user/profile/ProfileInfoCard.vue | 27 +- .../ProfileIdentityBindingsSection.spec.ts | 47 ++- .../profile/__tests__/ProfileInfoCard.spec.ts | 20 ++ frontend/src/views/admin/SettingsView.vue | 18 +- .../admin/__tests__/SettingsView.spec.ts | 31 +- .../src/views/auth/LinuxDoCallbackView.vue | 9 +- .../src/views/auth/WechatCallbackView.vue | 8 +- .../__tests__/LinuxDoCallbackView.spec.ts | 27 ++ .../auth/__tests__/WechatCallbackView.spec.ts | 28 ++ 30 files changed, 1072 insertions(+), 97 deletions(-) create mode 100644 frontend/src/api/__tests__/admin.users.spec.ts create mode 100644 frontend/src/api/__tests__/user.spec.ts diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index d7e15377..f8b22ee7 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -28,6 +28,26 @@ jobs: working-directory: backend run: make test-integration + frontend: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 9 + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '20' + cache: 'pnpm' + cache-dependency-path: frontend/pnpm-lock.yaml + - name: Install frontend dependencies + working-directory: frontend + run: pnpm install --frozen-lockfile + - name: Frontend typecheck and critical vitest + run: make test-frontend + golangci-lint: runs-on: ubuntu-latest steps: @@ -46,4 +66,4 @@ jobs: with: version: v2.9 args: --timeout=30m - working-directory: backend \ No newline at end of file + working-directory: backend diff --git a/Makefile b/Makefile index fd6a5a9a..d00d0c4f 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,12 @@ -.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan +.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-frontend-critical test-datamanagementd secret-scan + +FRONTEND_CRITICAL_VITEST := \ + src/views/auth/__tests__/LinuxDoCallbackView.spec.ts \ + src/views/auth/__tests__/WechatCallbackView.spec.ts \ + src/views/user/__tests__/PaymentView.spec.ts \ + src/views/user/__tests__/PaymentResultView.spec.ts \ + src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \ + src/views/admin/__tests__/SettingsView.spec.ts # 一键编译前后端 build: build-backend build-frontend @@ -24,6 +32,10 @@ test-backend: test-frontend: @pnpm --dir frontend run lint:check @pnpm --dir frontend run typecheck + @$(MAKE) test-frontend-critical + +test-frontend-critical: + @pnpm --dir frontend exec vitest run $(FRONTEND_CRITICAL_VITEST) test-datamanagementd: @cd datamanagement && go test ./... diff --git a/README.md b/README.md index 3e609d65..aa27d907 100644 --- a/README.md +++ b/README.md @@ -42,10 +42,18 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot - **Smart Scheduling** - Intelligent account selection with sticky sessions - **Concurrency Control** - Per-user and per-account concurrency limits - **Rate Limiting** - Configurable request and token rate limits -- **Built-in Payment System** - Supports EasyPay, Alipay, WeChat Pay, and Stripe for user self-service top-up, no separate payment service needed ([Configuration Guide](docs/PAYMENT.md)) +- **Built-in Payment System** - Supports EasyPay, Alipay, WeChat Pay, and Stripe for user self-service top-up, no separate payment service needed ([Payment Setup](#payment)) - **Admin Dashboard** - Web interface for monitoring and management - **External System Integration** - Embed external systems (e.g. ticketing) via iframe to extend the admin dashboard +## Payment + +Sub2API includes the payment system in the main service. No standalone payment service or separate payment guide is required. + +- Supported providers: EasyPay, Alipay, WeChat Pay, Stripe +- The frontend keeps user-facing methods unified; admins choose the backing source in `Admin -> Settings -> Payment` +- Callback URLs are generated from the site domain when configuring providers + ## ❤️ Sponsors > [Want to appear here?](mailto:support@pincc.ai) @@ -109,7 +117,7 @@ Community projects that extend or integrate with Sub2API: | Project | Description | Features | |---------|-------------|----------| -| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~Self-service payment system~~ | **Now Built-in** — Payment is now integrated into Sub2API, no separate deployment needed. See [Payment Configuration Guide](docs/PAYMENT.md) | +| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~Self-service payment system~~ | **Now Built-in** — Payment is now integrated into Sub2API, no separate deployment needed. See [Payment Setup](#payment) | | [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native | ## Tech Stack diff --git a/README_CN.md b/README_CN.md index add32a17..530a0c80 100644 --- a/README_CN.md +++ b/README_CN.md @@ -41,10 +41,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 - **智能调度** - 智能账号选择,支持粘性会话 - **并发控制** - 用户级和账号级并发限制 - **速率限制** - 可配置的请求和 Token 速率限制 -- **内置支付系统** - 支持 EasyPay 易支付、支付宝官方、微信官方、Stripe,用户自助充值,无需独立部署支付服务([配置指南](docs/PAYMENT_CN.md)) +- **内置支付系统** - 支持 EasyPay 易支付、支付宝官方、微信官方、Stripe,用户自助充值,无需独立部署支付服务([支付说明](#支付)) - **管理后台** - Web 界面进行监控和管理 - **外部系统集成** - 支持通过 iframe 嵌入外部系统(如工单等),扩展管理后台功能 +## 支付 + +Sub2API 已将支付系统集成到主服务中,无需独立支付服务,也不再依赖单独的支付配置文档。 + +- 支持服务商:EasyPay 易支付、支付宝官方、微信官方、Stripe +- 前台统一展示用户可见支付方式,管理员在 `管理后台 -> 设置 -> 支付` 里选择对应来源 +- 添加服务商时会基于站点域名生成回调地址 + ## ❤️ 赞助商 > [想出现在这里?](mailto:support@pincc.ai) @@ -108,7 +116,7 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 | 项目 | 说明 | 功能 | |------|------|------| -| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~自助支付系统~~ | **已内置** — 支付功能已集成到 Sub2API 中,无需独立部署。详见 [支付配置指南](docs/PAYMENT_CN.md) | +| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~自助支付系统~~ | **已内置** — 支付功能已集成到 Sub2API 中,无需独立部署。详见 [支付说明](#支付) | | [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 | ## 技术栈 diff --git a/README_JA.md b/README_JA.md index ccd595b9..b852b358 100644 --- a/README_JA.md +++ b/README_JA.md @@ -42,10 +42,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを - **スマートスケジューリング** - スティッキーセッション付きのインテリジェントなアカウント選択 - **同時実行制御** - ユーザーごと・アカウントごとの同時実行数制限 - **レート制限** - 設定可能なリクエスト数およびトークンレート制限 -- **内蔵決済システム** - EasyPay、Alipay、WeChat Pay、Stripe に対応。ユーザーのセルフサービスチャージが可能で、別途決済サービスのデプロイは不要([設定ガイド](docs/PAYMENT.md)) +- **内蔵決済システム** - EasyPay、Alipay、WeChat Pay、Stripe に対応。ユーザーのセルフサービスチャージが可能で、別途決済サービスのデプロイは不要([決済案内](#決済)) - **管理ダッシュボード** - 監視・管理のための Web インターフェース - **外部システム連携** - 外部システム(チケット管理など)を iframe 経由で管理ダッシュボードに埋め込み可能 +## 決済 + +Sub2API の決済機能は本体に統合されています。独立した決済サービスや別個の決済ガイドは不要です。 + +- 対応プロバイダー: EasyPay、Alipay、WeChat Pay、Stripe +- フロントエンドではユーザー向け決済方法を統一表示し、管理者は `管理画面 -> 設定 -> 決済` で実際の接続先を選択します +- プロバイダー設定時のコールバック URL はサイトドメインから自動生成されます + ## ❤️ スポンサー > [こちらに掲載しませんか?](mailto:support@pincc.ai) @@ -108,7 +116,7 @@ Sub2API を拡張・統合するコミュニティプロジェクト: | プロジェクト | 説明 | 機能 | |---------|-------------|----------| -| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~セルフサービス決済システム~~ | **内蔵済み** — 決済機能は Sub2API に統合されました。別途デプロイは不要です。[決済設定ガイド](docs/PAYMENT.md)をご参照ください | +| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~セルフサービス決済システム~~ | **内蔵済み** — 決済機能は Sub2API に統合されました。別途デプロイは不要です。[決済案内](#決済)をご参照ください | | [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | モバイル管理コンソール | ユーザー管理、アカウント管理、監視ダッシュボード、マルチバックエンド切り替えが可能なクロスプラットフォームアプリ(iOS/Android/Web)。Expo + React Native で構築 | ## 技術スタック diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 867d8c9e..f74c2b72 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -249,7 +249,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) { return } - updatedUser, err := h.userService.UnbindUserAuthProvider( + updatedUser, unbound, err := h.userService.UnbindUserAuthProviderWithResult( c.Request.Context(), subject.UserID, c.Param("provider"), @@ -258,7 +258,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) { response.ErrorFrom(c, err) return } - if h.authService != nil { + if unbound && h.authService != nil { if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil { response.ErrorFrom(c, err) return @@ -512,7 +512,7 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity var avatarSource *userProfileSourceContext avatarValue := strings.TrimSpace(user.AvatarURL) for _, summary := range thirdParty { - if avatarValue != "" && avatarValue == strings.TrimSpace(summary.DisplayName) { + if avatarValue != "" && avatarValue == strings.TrimSpace(summary.AvatarURL) { avatarSource = buildUserProfileSourceContext(summary.Provider) break } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index c212603b..e4985a22 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -636,6 +636,50 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure require.Equal(t, int64(5), repo.user.TokenVersion) } +func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 24, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 4, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "identity@example.com", + }, + }, + } + refreshTokenCache := &userHandlerRefreshTokenCacheStub{} + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 24}) + c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}} + + handler.UnbindIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Empty(t, repo.unbound) + require.Empty(t, refreshTokenCache.revokedUserIDs) + require.Equal(t, int64(4), repo.user.TokenVersion) +} + func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { gin.SetMode(gin.TestMode) @@ -728,7 +772,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { require.Equal(t, "wechat", resp.Data.Provider) require.Equal(t, "GET", resp.Data.Method) require.True(t, resp.Data.UseBrowserRedirect) - require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/start") + require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/bind/start") require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user") require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile") } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 3d933dbc..30ddf0a2 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -85,7 +85,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" }, "oidc": { "provider": "oidc", @@ -93,7 +93,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" }, "wechat": { "provider": "wechat", @@ -101,7 +101,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" } }, "identity_bindings": { @@ -122,7 +122,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" }, "oidc": { "provider": "oidc", @@ -130,7 +130,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" }, "wechat": { "provider": "wechat", @@ -138,7 +138,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" } }, "auth_bindings": { @@ -159,7 +159,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" }, "oidc": { "provider": "oidc", @@ -167,7 +167,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" }, "wechat": { "provider": "wechat", @@ -175,7 +175,7 @@ func TestAPIContracts(t *testing.T) { "bound_count": 0, "can_bind": true, "can_unbind": false, - "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" } }, "run_mode": "standard" diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index b4b75795..642a2103 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -63,8 +63,20 @@ func RegisterAuthRoutes( FailureMode: middleware.RateLimitFailClose, }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) + auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.LinuxDoOAuthStart(c) + }) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart) + auth.GET("/oauth/wechat/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.WeChatOAuthStart(c) + }) auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback) auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart) auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback) @@ -129,6 +141,12 @@ func RegisterAuthRoutes( h.Auth.CreateWeChatOAuthAccount, ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) + auth.GET("/oauth/oidc/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.OIDCOAuthStart(c) + }) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{ @@ -165,23 +183,5 @@ func RegisterAuthRoutes( // 撤销所有会话(需要认证) authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions) authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie) - authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) { - query := c.Request.URL.Query() - query.Set("intent", "bind_current_user") - c.Request.URL.RawQuery = query.Encode() - h.Auth.LinuxDoOAuthStart(c) - }) - authenticated.GET("/auth/oauth/oidc/bind/start", func(c *gin.Context) { - query := c.Request.URL.Query() - query.Set("intent", "bind_current_user") - c.Request.URL.RawQuery = query.Encode() - h.Auth.OIDCOAuthStart(c) - }) - authenticated.GET("/auth/oauth/wechat/bind/start", func(c *gin.Context) { - query := c.Request.URL.Query() - query.Set("intent", "bind_current_user") - c.Request.URL.RawQuery = query.Encode() - h.Auth.WeChatOAuthStart(c) - }) } } diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go index f0483800..78f1185d 100644 --- a/backend/internal/service/auth_email_binding.go +++ b/backend/internal/service/auth_email_binding.go @@ -11,6 +11,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // BindEmailIdentity verifies and binds a local email/password identity to the @@ -69,6 +70,7 @@ func (s *AuthService) BindEmailIdentity( if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil { return nil, err } + s.revokeEmailIdentitySessions(ctx, userID) return currentUser, nil } @@ -87,6 +89,7 @@ func (s *AuthService) BindEmailIdentity( } } + s.revokeEmailIdentitySessions(ctx, userID) return currentUser, nil } @@ -219,6 +222,12 @@ func (s *AuthService) updateBoundEmailIdentityWithClient( return nil } +func (s *AuthService) revokeEmailIdentitySessions(ctx context.Context, userID int64) { + if err := s.RevokeAllUserSessions(ctx, userID); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v", userID, err) + } +} + func replaceBoundEmailAuthIdentityWithClient( ctx context.Context, client *dbent.Client, diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index d32a4a40..cced842a 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "errors" + "sync" "testing" "time" @@ -13,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" @@ -54,6 +56,16 @@ func newAuthServiceForEmailBind( settings map[string]string, emailCache service.EmailCache, defaultSubAssigner service.DefaultSubscriptionAssigner, +) (*service.AuthService, service.UserRepository, *dbent.Client) { + return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil) +} + +func newAuthServiceForEmailBindWithRefreshCache( + t *testing.T, + settings map[string]string, + emailCache service.EmailCache, + defaultSubAssigner service.DefaultSubscriptionAssigner, + refreshTokenCache service.RefreshTokenCache, ) (*service.AuthService, service.UserRepository, *dbent.Client) { t.Helper() @@ -98,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( emailSvc = service.NewEmailService(settingRepo, emailCache) } - svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner) + svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner) return svc, repo, client } @@ -427,6 +439,61 @@ func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t require.Equal(t, 0, newIdentityCount) } +func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) { + ctx := context.Background() + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + refreshTokenCache := newEmailBindRefreshTokenCacheStub() + userRepo := newEmailBindUserRepoStub(&service.User{ + ID: 41, + Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain, + Username: "legacy-user", + PasswordHash: "old-hash", + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 4, + }) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-bind-email-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + } + emailService := service.NewEmailService(nil, cache) + svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil) + + oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{ + ID: 41, + Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain, + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 4, + }, "") + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password") + require.NoError(t, err) + require.NotNil(t, updatedUser) + + storedUser, err := userRepo.GetByID(ctx, 41) + require.NoError(t, err) + require.Equal(t, "new@example.com", storedUser.Email) + require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash)) + + _, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken) + require.ErrorIs(t, err, service.ErrTokenRevoked) + + _, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken) + require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid)) +} + type emailBindSettingRepoStub struct { values map[string]string } @@ -527,3 +594,260 @@ func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int6 func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { return 0, nil } + +type emailBindRefreshTokenCacheStub struct { + mu sync.Mutex + tokens map[string]*service.RefreshTokenData + userSets map[int64]map[string]struct{} + families map[string]map[string]struct{} +} + +func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub { + return &emailBindRefreshTokenCacheStub{ + tokens: make(map[string]*service.RefreshTokenData), + userSets: make(map[int64]map[string]struct{}), + families: make(map[string]map[string]struct{}), + } +} + +func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + cloned := *data + s.tokens[tokenHash] = &cloned + return nil +} + +func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) { + s.mu.Lock() + defer s.mu.Unlock() + data, ok := s.tokens[tokenHash] + if !ok { + return nil, service.ErrRefreshTokenNotFound + } + cloned := *data + return &cloned, nil +} + +func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tokens, tokenHash) + for _, tokenSet := range s.userSets { + delete(tokenSet, tokenHash) + } + for _, tokenSet := range s.families { + delete(tokenSet, tokenHash) + } + return nil +} + +func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for tokenHash := range s.userSets[userID] { + delete(s.tokens, tokenHash) + for _, tokenSet := range s.families { + delete(tokenSet, tokenHash) + } + } + delete(s.userSets, userID) + return nil +} + +func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error { + s.mu.Lock() + defer s.mu.Unlock() + for tokenHash := range s.families[familyID] { + delete(s.tokens, tokenHash) + for _, tokenSet := range s.userSets { + delete(tokenSet, tokenHash) + } + } + delete(s.families, familyID) + return nil +} + +func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.userSets[userID] == nil { + s.userSets[userID] = make(map[string]struct{}) + } + s.userSets[userID][tokenHash] = struct{}{} + return nil +} + +func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.families[familyID] == nil { + s.families[familyID] = make(map[string]struct{}) + } + s.families[familyID][tokenHash] = struct{}{} + return nil +} + +func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + tokenSet := s.userSets[userID] + out := make([]string, 0, len(tokenSet)) + for tokenHash := range tokenSet { + out = append(out, tokenHash) + } + return out, nil +} + +func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + tokenSet := s.families[familyID] + out := make([]string, 0, len(tokenSet)) + for tokenHash := range tokenSet { + out = append(out, tokenHash) + } + return out, nil +} + +func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.families[familyID][tokenHash] + return ok, nil +} + +type emailBindUserRepoStub struct { + mu sync.Mutex + usersByID map[int64]*service.User + usersByEmail map[string]*service.User +} + +func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub { + cloned := cloneEmailBindUser(user) + return &emailBindUserRepoStub{ + usersByID: map[int64]*service.User{ + cloned.ID: cloned, + }, + usersByEmail: map[string]*service.User{ + cloned.Email: cloned, + }, + } +} + +func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil } + +func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) { + s.mu.Lock() + defer s.mu.Unlock() + user, ok := s.usersByID[id] + if !ok { + return nil, service.ErrUserNotFound + } + return cloneEmailBindUser(user), nil +} + +func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) { + s.mu.Lock() + defer s.mu.Unlock() + user, ok := s.usersByEmail[email] + if !ok { + return nil, service.ErrUserNotFound + } + return cloneEmailBindUser(user), nil +} + +func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error { + s.mu.Lock() + defer s.mu.Unlock() + existing, ok := s.usersByID[user.ID] + if !ok { + return service.ErrUserNotFound + } + delete(s.usersByEmail, existing.Email) + cloned := cloneEmailBindUser(user) + s.usersByID[user.ID] = cloned + s.usersByEmail[cloned.Email] = cloned + return nil +} + +func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil } + +func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error { + panic("unexpected DeleteUserAvatar call") +} + +func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} + +func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error { + return nil +} + +func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } +func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } + +func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.usersByEmail[email] + return ok, nil +} + +func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} + +func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} + +func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) { + return nil, nil +} + +func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { + return nil +} + +func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil } +func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil } + +func cloneEmailBindUser(user *service.User) *service.User { + if user == nil { + return nil + } + cloned := *user + return &cloned +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index c16d810b..a211103f 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -127,6 +127,7 @@ type UserIdentitySummary struct { Bound bool `json:"bound"` BoundCount int `json:"bound_count"` DisplayName string `json:"display_name,omitempty"` + AvatarURL string `json:"-"` SubjectHint string `json:"subject_hint,omitempty"` ProviderKey string `json:"provider_key,omitempty"` VerifiedAt *time.Time `json:"verified_at,omitempty"` @@ -228,6 +229,7 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro if err != nil { return nil, fmt.Errorf("get user: %w", err) } + normalizeLoadedUserTokenVersion(user) if err := s.hydrateUserAvatar(ctx, user); err != nil { return nil, fmt.Errorf("get user avatar: %w", err) } @@ -323,29 +325,34 @@ func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUs } func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) { + user, _, err := s.UnbindUserAuthProviderWithResult(ctx, userID, provider) + return user, err +} + +func (s *UserService) UnbindUserAuthProviderWithResult(ctx context.Context, userID int64, provider string) (*User, bool, error) { provider = normalizeUserIdentityProvider(provider) if provider == "" || provider == "email" { - return nil, ErrIdentityProviderInvalid + return nil, false, ErrIdentityProviderInvalid } user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - return nil, fmt.Errorf("get user: %w", err) + return nil, false, fmt.Errorf("get user: %w", err) } records, err := s.listUserAuthIdentities(ctx, userID) if err != nil { - return nil, err + return nil, false, err } if len(filterUserAuthIdentities(records, provider)) == 0 { - return user, nil + return user, false, nil } if !s.canUnbindProvider(provider, user, records) { - return nil, ErrIdentityUnbindLastMethod + return nil, false, ErrIdentityUnbindLastMethod } if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil { - return nil, err + return nil, false, err } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) @@ -353,9 +360,9 @@ func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, updatedUser, err := s.GetProfile(ctx, userID) if err != nil { - return nil, err + return nil, false, err } - return updatedUser, nil + return updatedUser, true, nil } // UpdateProfile 更新用户资料 @@ -655,6 +662,7 @@ func (s *UserService) buildProviderIdentitySummary(provider string, user *User, summary.Bound = true summary.BoundCount = len(filtered) summary.DisplayName = userAuthIdentityDisplayName(primary) + summary.AvatarURL = strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "avatar_url", "suggested_avatar_url", "headimgurl")) summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject) summary.ProviderKey = strings.TrimSpace(primary.ProviderKey) summary.VerifiedAt = primary.VerifiedAt @@ -672,7 +680,7 @@ func (s *UserService) canUnbindProvider(provider string, user *User, records []U return false } - if s.buildEmailIdentitySummary(user, records).Bound { + if s.canUseEmailAsSignInMethod(user, records) { return true } @@ -688,6 +696,44 @@ func (s *UserService) canUnbindProvider(provider string, user *User, records []U return false } +func (s *UserService) canUseEmailAsSignInMethod(user *User, records []UserAuthIdentityRecord) bool { + if user == nil { + return false + } + + email := strings.ToLower(strings.TrimSpace(user.Email)) + if email == "" || isReservedEmail(email) { + return false + } + + if emailSignupSourceAllowsLogin(user.SignupSource) { + return true + } + + for _, record := range filterUserAuthIdentities(records, "email") { + if emailIdentitySupportsSignIn(record) { + return true + } + } + + return false +} + +func emailSignupSourceAllowsLogin(signupSource string) bool { + signupSource = strings.ToLower(strings.TrimSpace(signupSource)) + return signupSource == "" || signupSource == "email" +} + +func emailIdentitySupportsSignIn(record UserAuthIdentityRecord) bool { + source := strings.TrimSpace(firstStringIdentityValue(record.Metadata, "source")) + switch source { + case "auth_service_email_bind", "auth_service_login_backfill", "auth_service_dual_write": + return true + default: + return false + } +} + func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { if userID <= 0 || s == nil || s.userRepo == nil { return nil, nil @@ -709,11 +755,11 @@ func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, err path := "" switch provider { case "linuxdo": - path = "/api/v1/auth/oauth/linuxdo/start" + path = "/api/v1/auth/oauth/linuxdo/bind/start" case "oidc": - path = "/api/v1/auth/oauth/oidc/start" + path = "/api/v1/auth/oauth/oidc/bind/start" case "wechat": - path = "/api/v1/auth/oauth/wechat/start" + path = "/api/v1/auth/oauth/wechat/bind/start" default: return "", ErrIdentityProviderInvalid } @@ -889,12 +935,20 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { if err != nil { return nil, fmt.Errorf("get user: %w", err) } + normalizeLoadedUserTokenVersion(user) if err := s.hydrateUserAvatar(ctx, user); err != nil { return nil, fmt.Errorf("get user avatar: %w", err) } return user, nil } +func normalizeLoadedUserTokenVersion(user *User) { + if user == nil { + return + } + user.TokenVersion = resolvedTokenVersion(user) +} + // TouchLastActive 通过防抖更新 users.last_active_at,减少鉴权热路径写放大。 // 该操作为尽力而为,不应中断正常请求。 func (s *UserService) TouchLastActive(ctx context.Context, userID int64) { diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 0ad95356..ff55c2a5 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -387,6 +387,70 @@ func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) { require.Empty(t, repo.unboundProviders) } +func TestGetProfileIdentitySummaries_DoesNotTreatOAuthOnlyCompatEmailAsAlternativeLoginMethod(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 10, + Email: "oauth-only@example.com", + SignupSource: "oidc", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "oidc", + ProviderKey: "https://issuer.example.com", + ProviderSubject: "oidc-only-subject", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 10, repo.getByIDUser) + + require.NoError(t, err) + require.False(t, summaries.OIDC.CanUnbind) + + _, err = svc.UnbindUserAuthProvider(context.Background(), 10, "oidc") + require.ErrorIs(t, err, ErrIdentityUnbindLastMethod) + require.Empty(t, repo.unboundProviders) +} + +func TestGetProfileIdentitySummaries_DoesNotTreatCompatBackfilledEmailIdentityAsAlternativeLoginMethod(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 11, + Email: "oauth-only@example.com", + SignupSource: "wechat", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "oauth-only@example.com", + Metadata: map[string]any{ + "backfill_source": "users.email", + "migration": "109_auth_identity_compat_backfill", + }, + }, + { + ProviderType: "wechat", + ProviderKey: "wechat", + ProviderSubject: "wechat-only-subject", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 11, repo.getByIDUser) + + require.NoError(t, err) + require.True(t, summaries.Email.Bound) + require.False(t, summaries.WeChat.CanUnbind) + + _, err = svc.UnbindUserAuthProvider(context.Background(), 11, "wechat") + require.ErrorIs(t, err, ErrIdentityUnbindLastMethod) + require.Empty(t, repo.unboundProviders) +} + func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) { repo := &mockUserRepo{ getByIDUser: &User{ @@ -451,6 +515,42 @@ func TestGetProfileIdentitySummaries_HidesBindActionWhenProviderExplicitlyDisabl require.Empty(t, summaries.LinuxDo.BindStartPath) } +func TestGetProfileIdentitySummaries_UsesBindStartRoute(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 16, + Email: "alice@example.com", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "alice@example.com", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 16, repo.getByIDUser) + + require.NoError(t, err) + require.Equal( + t, + "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile", + summaries.LinuxDo.BindStartPath, + ) + require.Equal( + t, + "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile", + summaries.OIDC.BindStartPath, + ) + require.Equal( + t, + "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile", + summaries.WeChat.BindStartPath, + ) +} + func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { repo := &mockUserRepo{} svc := NewUserService(repo, nil, nil, nil) // billingCache = nil diff --git a/frontend/src/api/__tests__/admin.users.spec.ts b/frontend/src/api/__tests__/admin.users.spec.ts new file mode 100644 index 00000000..37656b78 --- /dev/null +++ b/frontend/src/api/__tests__/admin.users.spec.ts @@ -0,0 +1,117 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { post } = vi.hoisted(() => ({ + post: vi.fn(), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post, + }, +})) + +import { + bindUserAuthIdentity, + type AdminBindAuthIdentityRequest, + type AdminBoundAuthIdentity, +} from '@/api/admin/users' + +type Assert = T +type IsExact = ( + (() => G extends T ? 1 : 2) extends (() => G extends U ? 1 : 2) + ? ((() => G extends U ? 1 : 2) extends (() => G extends T ? 1 : 2) ? true : false) + : false +) + +type ExpectedAdminBindAuthIdentityRequest = { + provider_type: string + provider_key: string + provider_subject: string + issuer?: string + metadata?: Record + channel?: { + channel: string + channel_app_id: string + channel_subject: string + metadata?: Record + } +} + +type ExpectedAdminBoundAuthIdentity = { + user_id: number + provider_type: string + provider_key: string + provider_subject: string + verified_at?: string | null + issuer?: string | null + metadata: Record | null + created_at: string + updated_at: string + channel?: { + channel: string + channel_app_id: string + channel_subject: string + metadata: Record | null + created_at: string + updated_at: string + } | null +} + +const requestContractExact: Assert< + IsExact +> = true +const responseContractExact: Assert< + IsExact +> = true + +describe('admin users api auth identity binding', () => { + beforeEach(() => { + post.mockReset() + }) + + it('posts the backend-compatible auth identity bind payload and returns the backend response shape', async () => { + const payload: AdminBindAuthIdentityRequest = { + provider_type: 'wechat', + provider_key: 'wechat-main', + provider_subject: 'union-123', + metadata: { source: 'admin-repair' }, + channel: { + channel: 'open', + channel_app_id: 'wx-open', + channel_subject: 'openid-123', + metadata: { scene: 'migration' }, + }, + } + + const response: AdminBoundAuthIdentity = { + user_id: 9, + provider_type: 'wechat', + provider_key: 'wechat-main', + provider_subject: 'union-123', + verified_at: '2026-04-22T00:00:00Z', + issuer: null, + metadata: { source: 'admin-repair' }, + created_at: '2026-04-22T00:00:00Z', + updated_at: '2026-04-22T00:00:00Z', + channel: { + channel: 'open', + channel_app_id: 'wx-open', + channel_subject: 'openid-123', + metadata: { scene: 'migration' }, + created_at: '2026-04-22T00:00:00Z', + updated_at: '2026-04-22T00:00:00Z', + }, + } + post.mockResolvedValue({ data: response }) + + const result = await bindUserAuthIdentity(9, payload) + + expect(post).toHaveBeenCalledWith('/admin/users/9/auth-identities', payload) + expect(result).toEqual(response) + }) + + it('keeps bind auth identity request and response types aligned with the backend contract', () => { + expect(requestContractExact).toBe(true) + expect(responseContractExact).toBe(true) + }) +}) diff --git a/frontend/src/api/__tests__/client.spec.ts b/frontend/src/api/__tests__/client.spec.ts index 0f663e76..a46c39eb 100644 --- a/frontend/src/api/__tests__/client.spec.ts +++ b/frontend/src/api/__tests__/client.spec.ts @@ -91,6 +91,22 @@ describe('API Client', () => { const config = adapter.mock.calls[0][0] expect(config.params?.timezone).toBeUndefined() }) + + it('请求默认带 withCredentials 以支持跨域 cookie', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.post('/auth/oauth/bind-token') + + const config = adapter.mock.calls[0][0] + expect(config.withCredentials).toBe(true) + }) }) // --- 响应拦截器 --- diff --git a/frontend/src/api/__tests__/user.spec.ts b/frontend/src/api/__tests__/user.spec.ts new file mode 100644 index 00000000..887046da --- /dev/null +++ b/frontend/src/api/__tests__/user.spec.ts @@ -0,0 +1,32 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +describe('user api oauth binding urls', () => { + beforeEach(() => { + vi.resetModules() + vi.stubEnv('VITE_API_BASE_URL', 'https://api.example.com/api/v1') + }) + + afterEach(() => { + vi.unstubAllEnvs() + }) + + it('builds third-party bind urls against the bind start endpoint', async () => { + const { buildOAuthBindingStartURL } = await import('@/api/user') + + expect(buildOAuthBindingStartURL('linuxdo', { redirectTo: '/settings/profile' })).toBe( + 'https://api.example.com/api/v1/auth/oauth/linuxdo/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user' + ) + expect( + buildOAuthBindingStartURL('wechat', { + redirectTo: '/settings/profile', + wechatOAuthSettings: { + wechat_oauth_open_enabled: true, + wechat_oauth_mp_enabled: false, + wechat_oauth_mobile_enabled: false + } + }) + ).toBe( + 'https://api.example.com/api/v1/auth/oauth/wechat/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user&mode=open' + ) + }) +}) diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index 1bb3d54c..3c75a6c4 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -8,26 +8,40 @@ import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/ export interface AdminBindAuthIdentityChannelRequest { channel: string - channel_app_id?: string + channel_app_id: string channel_subject: string - metadata?: Record + metadata?: Record | null } export interface AdminBindAuthIdentityRequest { provider_type: string provider_key: string provider_subject: string - issuer?: string - metadata?: Record + issuer?: string | null + metadata?: Record | null channel?: AdminBindAuthIdentityChannelRequest } +export interface AdminBoundAuthIdentityChannel { + channel: string + channel_app_id: string + channel_subject: string + metadata: Record | null + created_at: string + updated_at: string +} + export interface AdminBoundAuthIdentity { - identity_id: number + user_id: number provider_type: string provider_key: string provider_subject: string - channel_id?: number | null + verified_at?: string | null + issuer?: string | null + metadata: Record | null + created_at: string + updated_at: string + channel?: AdminBoundAuthIdentityChannel | null } /** diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 9621c26e..f49f3a1f 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -194,6 +194,7 @@ export interface OAuthTokenResponse { } export interface PendingOAuthBindLoginResponse extends Partial { + auth_result?: string redirect?: string error?: string requires_2fa?: boolean @@ -206,7 +207,9 @@ export interface PendingOAuthBindLoginResponse extends Partial : t('profile.authBindings.confirmEmailBindAction') ) +function resolveLegacyCompatibleWeChatSettings( + settings: WeChatOAuthPublicSettings | null | undefined +): (WeChatOAuthPublicSettings & { + wechat_oauth_open_enabled: boolean + wechat_oauth_mp_enabled: boolean +}) | null { + if (!settings) { + return null + } + + if (hasExplicitWeChatOAuthCapabilities(settings)) { + return settings + } + + if (typeof settings.wechat_oauth_enabled !== 'boolean') { + return null + } + + return { + ...settings, + wechat_oauth_open_enabled: settings.wechat_oauth_enabled, + wechat_oauth_mp_enabled: settings.wechat_oauth_enabled, + } +} + const wechatOAuthSettings = computed(() => { - if (hasExplicitWeChatOAuthCapabilities(appStore.cachedPublicSettings)) { - return appStore.cachedPublicSettings + const cachedSettings = resolveLegacyCompatibleWeChatSettings(appStore.cachedPublicSettings) + if (cachedSettings) { + return cachedSettings } - if (typeof props.wechatOpenEnabled === 'boolean' && typeof props.wechatMpEnabled === 'boolean') { - return { - wechat_oauth_enabled: props.wechatEnabled, - wechat_oauth_open_enabled: props.wechatOpenEnabled, - wechat_oauth_mp_enabled: props.wechatMpEnabled, - } - } - - return null + return resolveLegacyCompatibleWeChatSettings({ + wechat_oauth_enabled: props.wechatEnabled, + wechat_oauth_open_enabled: props.wechatOpenEnabled, + wechat_oauth_mp_enabled: props.wechatMpEnabled, + }) }) const resolvedWeChatBinding = computed(() => resolveWeChatOAuthStartStrict(wechatOAuthSettings.value)) @@ -362,6 +384,17 @@ function getBindingDetails(provider: UserAuthProvider): UserAuthBindingStatus | return binding } +function getDisplayableEmail(user: User | null | undefined): string { + const email = user?.email?.trim() || '' + if (!email) { + return '' + } + if (email.endsWith('.invalid') && !getBindingStatusForUser(user, 'email')) { + return '' + } + return email +} + function isProviderEnabledForBinding(provider: BindableProvider): boolean { if (provider === 'linuxdo') { return props.linuxdoEnabled @@ -444,14 +477,7 @@ function providerIconClass(provider: UserAuthProvider): string { function providerSummary(provider: UserAuthProvider): string { if (provider === 'email') { - const email = currentUser.value?.email?.trim() || '' - if (!email) { - return '' - } - if (currentUser.value?.email_bound === false && email.endsWith('.invalid')) { - return '' - } - return email + return getDisplayableEmail(currentUser.value) } return '' } diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue index 4544c337..37ee8a55 100644 --- a/frontend/src/components/user/profile/ProfileInfoCard.vue +++ b/frontend/src/components/user/profile/ProfileInfoCard.vue @@ -185,7 +185,7 @@ import Icon from '@/components/icons/Icon.vue' import ProfileAvatarCard from '@/components/user/profile/ProfileAvatarCard.vue' import ProfileEditForm from '@/components/user/profile/ProfileEditForm.vue' import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue' -import type { User, UserAuthProvider, UserProfileSourceContext } from '@/types' +import type { User, UserAuthBindingStatus, UserAuthProvider, UserProfileSourceContext } from '@/types' const props = withDefaults(defineProps<{ user: User | null @@ -206,6 +206,29 @@ const props = withDefaults(defineProps<{ const { t } = useI18n() +function normalizeBindingStatus(binding: boolean | UserAuthBindingStatus | undefined): boolean | null { + if (typeof binding === 'boolean') { + return binding + } + if (!binding) { + return null + } + if (typeof binding.bound === 'boolean') { + return binding.bound + } + return Boolean(binding.provider_subject || binding.issuer || binding.provider_key) +} + +function isEmailBound(user: User | null | undefined): boolean { + if (typeof user?.email_bound === 'boolean') { + return user.email_bound + } + + const nested = user?.auth_bindings?.email ?? user?.identity_bindings?.email + const normalized = normalizeBindingStatus(nested) + return normalized ?? false +} + const avatarUrl = computed(() => props.user?.avatar_url?.trim() || '') const displayName = computed(() => props.user?.username?.trim() || props.user?.email?.trim() || t('profile.user')) const primaryEmailDisplay = computed(() => { @@ -213,7 +236,7 @@ const primaryEmailDisplay = computed(() => { if (!email) { return '' } - if (props.user?.email_bound === false && email.endsWith('.invalid')) { + if (email.endsWith('.invalid') && !isEmailBound(props.user)) { return '' } return email diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts index 77d2219e..b54a1cce 100644 --- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts +++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts @@ -188,7 +188,7 @@ describe('ProfileIdentityBindingsSection', () => { expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(false) }) - it('hides the WeChat bind action when only the legacy aggregate setting is present', () => { + it('keeps the WeChat bind action visible when only the legacy aggregate setting is present', () => { const wrapper = mount(ProfileIdentityBindingsSection, { global: { plugins: [pinia], @@ -201,7 +201,28 @@ describe('ProfileIdentityBindingsSection', () => { }, }) - expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(false) + expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(true) + }) + + it('starts the WeChat bind flow when only the legacy aggregate setting is present', async () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + global: { + plugins: [pinia], + }, + props: { + user: createUser(), + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: true, + }, + }) + + await wrapper.get('[data-testid="profile-binding-wechat-action"]').trigger('click') + + expect(locationState.current.href).toContain('/api/v1/auth/oauth/wechat/start?') + expect(locationState.current.href).toContain('mode=open') + expect(locationState.current.href).toContain('intent=bind_current_user') + expect(locationState.current.href).toContain('redirect=%2Fprofile') }) it('uses explicit cached WeChat capabilities and ignores legacy prop fallbacks', () => { @@ -358,6 +379,28 @@ describe('ProfileIdentityBindingsSection', () => { expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound') }) + it('does not show a synthetic oauth-only email when only fallback auth bindings mark email as unbound', () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + global: { + plugins: [pinia], + }, + props: { + user: createUser({ + email: 'legacy-user@wechat-connect.invalid', + auth_bindings: { + email: { bound: false }, + }, + }), + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: false, + }, + }) + + expect(wrapper.text()).not.toContain('legacy-user@wechat-connect.invalid') + expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound') + }) + it('keeps the email form available for replacing a bound primary email', async () => { userApiMocks.sendEmailBindingCode.mockResolvedValue(undefined) userApiMocks.bindEmailIdentity.mockResolvedValue( diff --git a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts index c7e60d9b..51653c6a 100644 --- a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts +++ b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts @@ -152,6 +152,26 @@ describe('ProfileInfoCard', () => { expect(wrapper.text()).not.toContain('legacy-user@oidc-connect.invalid') }) + it('does not display synthetic oauth-only emails when only legacy identity bindings mark email as unbound', () => { + const wrapper = mount(ProfileInfoCard, { + props: { + user: createUser({ + email: 'legacy-user@wechat-connect.invalid', + identity_bindings: { + email: { bound: false } + } + }) + }, + global: { + stubs: { + Icon: true + } + } + }) + + expect(wrapper.text()).not.toContain('legacy-user@wechat-connect.invalid') + }) + it('renders the approved overview hero and two-column content shell', () => { const wrapper = mount(ProfileInfoCard, { props: { diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 5772c501..90daae5f 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3763,11 +3763,7 @@

{{ t("admin.settings.payment.description") }} {{ t("admin.settings.payment.enabledPaymentTypesHint") }} + locale.value.startsWith("zh") + ? "https://github.com/Wei-Shaw/sub2api/blob/main/README_CN.md#%E6%94%AF%E4%BB%98" + : "https://github.com/Wei-Shaw/sub2api/blob/main/README.md#payment", +); + type SettingsTab = | "general" | "security" diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts index 10c51b2a..c294756e 100644 --- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts +++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts @@ -46,6 +46,8 @@ const { showSuccess: vi.fn(), })); +const localeRef = vi.hoisted(() => ({ value: "zh-CN" })); + vi.mock("@/api", () => ({ adminAPI: { settings: { @@ -149,6 +151,8 @@ vi.mock("vue-i18n", async () => { "admin.settings.paymentVisibleMethods.sourceLabel": "支付来源", "admin.settings.paymentVisibleMethods.sourceHint": "启用后必须明确选择一个来源;未配置状态不会对外展示该支付方式。", "admin.settings.paymentVisibleMethods.sourceRequiredError": "{title} 已启用,请先选择支付来源。", + "admin.settings.payment.configGuide": "查看支付配置说明", + "admin.settings.payment.findProvider": "查看支持的支付方式", "admin.settings.openaiExperimentalScheduler.title": "OpenAI 实验调度策略", "admin.settings.openaiExperimentalScheduler.description": "默认关闭。开启后仅影响本网关在 OpenAI 账号间的实验性调度选择逻辑,不代表上游 OpenAI 官方能力。", }; @@ -157,7 +161,7 @@ vi.mock("vue-i18n", async () => { useI18n: () => ({ t: (key: string, params?: Record) => (translations[key] ?? key).replace(/\{(\w+)\}/g, (_, token) => params?.[token] ?? `{${token}}`), - locale: ref("zh-CN"), + locale: localeRef, }), }; }); @@ -429,6 +433,7 @@ describe("admin SettingsView payment visible method controls", () => { adminSettingsFetch.mockReset(); showError.mockReset(); showSuccess.mockReset(); + localeRef.value = "zh-CN"; getSettings.mockResolvedValue({ ...baseSettingsResponse }); updateSettings.mockImplementation(async (payload) => ({ @@ -489,6 +494,30 @@ describe("admin SettingsView payment visible method controls", () => { expect(wrapper.text()).not.toContain("支付来源"); }); + it("links payment guidance to README sections instead of removed payment docs", async () => { + const wrapper = mountView(); + + await flushPromises(); + await openPaymentTab(wrapper); + + const paymentLinks = wrapper + .findAll("a") + .filter((node) => + ["查看支付配置说明", "查看支持的支付方式"].includes(node.text()), + ); + + expect(paymentLinks).toHaveLength(2); + expect(paymentLinks[0]?.attributes("href")).toBe( + "https://github.com/Wei-Shaw/sub2api/blob/main/README_CN.md#%E6%94%AF%E4%BB%98", + ); + expect(paymentLinks[1]?.attributes("href")).toBe( + "https://github.com/Wei-Shaw/sub2api/blob/main/README_CN.md#%E6%94%AF%E4%BB%98", + ); + for (const link of paymentLinks) { + expect(link.attributes("href")).not.toContain("docs/PAYMENT"); + } + }); + it("does not submit legacy visible payment method settings", async () => { const wrapper = mountView(); diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 2cf4e694..f73d77de 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -456,7 +456,14 @@ function resolvePendingAccountAction( if (raw === 'email_required' || raw === 'create_account_required' || raw === 'create_account') { return 'create_account' } - if (raw === 'bind_login_required' || raw === 'bind_login') { + if ( + raw === 'bind_login_required' || + raw === 'bind_login' || + raw === 'existing_account' || + raw === 'existing_account_required' || + raw === 'existing_account_binding_required' || + raw === 'adopt_existing_user_by_email' + ) { return 'bind_login' } return 'none' diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue index bae20df8..9ecc5e47 100644 --- a/frontend/src/views/auth/WechatCallbackView.vue +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -613,8 +613,12 @@ async function handleBindCurrentAccount() { return } - await prepareOAuthBindAccessTokenCookie() - window.location.href = startURL + try { + await prepareOAuthBindAccessTokenCookie() + window.location.href = startURL + } catch (e: unknown) { + errorMessage.value = getRequestErrorMessage(e, t('auth.loginFailed')) + } } async function handleExistingAccountBinding() { diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts index 3fee2c27..333f8dc5 100644 --- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -336,6 +336,33 @@ describe('LinuxDoCallbackView', () => { ) }) + it('keeps rendering bind-login UI for legacy pending bind responses instead of treating them as success', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'adopt_existing_user_by_email', + redirect: '/profile/security', + email: 'existing@example.com' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '

' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(showSuccess).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + it('persists a pending auth session when the oauth flow still needs account creation', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'email_required', diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts index da41c987..7150dd7e 100644 --- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -621,6 +621,34 @@ describe('WechatCallbackView', () => { expect(locationState.current.href).toContain('mode=open') }) + it('shows an error and stays on the page when preparing bind-token for the current account fails', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/usage', + }) + getAuthTokenMock.mockReturnValue('current-auth-token') + prepareOAuthBindAccessTokenCookieMock.mockRejectedValue(new Error('bind token failed')) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.get('[data-testid="existing-account-submit"]').trigger('click').catch(() => undefined) + await flushPromises() + + expect(showErrorMock).toHaveBeenCalledWith('bind token failed') + expect(locationState.current.href).toBe('http://localhost/auth/wechat/callback') + }) + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { getPublicSettingsMock.mockResolvedValue({ invitation_code_enabled: true, From 82259d138066fa70233eca6419032601b2bdba7c Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 16:01:25 +0800 Subject: [PATCH 23/31] fix(auth): preserve resolved token version on oauth login --- .../handler/auth_oauth_pending_flow_test.go | 15 +++++++++++++++ backend/internal/service/auth_service.go | 3 +++ backend/internal/service/user.go | 15 +++++++++------ backend/internal/service/user_service.go | 3 ++- 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 9f9e497b..a4b7a297 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -778,6 +778,14 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"]) require.NotContains(t, payload, "adoption_required") + accessToken, ok := payload["access_token"].(string) + require.True(t, ok) + claims, err := handler.authService.ValidateToken(accessToken) + require.NoError(t, err) + reloadedUser, err := handler.userService.GetByID(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion) + decisionCount, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). Count(ctx) @@ -2033,6 +2041,13 @@ func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) { payload := decodeJSONResponseData(t, recorder) require.NotEmpty(t, payload["access_token"]) require.NotEmpty(t, payload["refresh_token"]) + accessToken, ok := payload["access_token"].(string) + require.True(t, ok) + claims, err := handler.authService.ValidateToken(accessToken) + require.NoError(t, err) + reloadedUser, err := handler.userService.GetByID(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion) identity, err := client.AuthIdentity.Query(). Where( diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 59442d1f..3bf9da3d 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -1500,6 +1500,9 @@ func resolvedTokenVersion(user *User) int64 { if user == nil { return 0 } + if user.TokenVersionResolved { + return user.TokenVersion + } material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash sum := sha256.Sum256([]byte(material)) diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index fa04d95e..9dc13381 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -23,12 +23,15 @@ type User struct { Status string AllowedGroups []int64 TokenVersion int64 // Incremented on password change to invalidate existing tokens - SignupSource string - LastLoginAt *time.Time - LastActiveAt *time.Time - LastUsedAt *time.Time - CreatedAt time.Time - UpdatedAt time.Time + // TokenVersionResolved indicates TokenVersion already contains the fingerprint-derived + // value expected in JWT claims and refresh-token state. + TokenVersionResolved bool + SignupSource string + LastLoginAt *time.Time + LastActiveAt *time.Time + LastUsedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a211103f..7ba401e7 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -943,10 +943,11 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { } func normalizeLoadedUserTokenVersion(user *User) { - if user == nil { + if user == nil || user.TokenVersionResolved { return } user.TokenVersion = resolvedTokenVersion(user) + user.TokenVersionResolved = true } // TouchLastActive 通过防抖更新 users.last_active_at,减少鉴权热路径写放大。 From ad4600964e11737e9021737a927f0af0431ec663 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 16:38:36 +0800 Subject: [PATCH 24/31] fix(ci): clean up lint and dead code --- .../internal/handler/auth_oauth_pending_flow.go | 12 ------------ backend/internal/repository/migrations_runner.go | 4 +++- backend/internal/repository/user_repo.go | 4 ---- .../user_repo_email_lookup_unit_test.go | 6 +++--- backend/internal/service/payment_service.go | 9 --------- .../service/payment_visible_method_instances.go | 15 --------------- .../layout/__tests__/AppSidebar.spec.ts | 2 +- .../views/admin/__tests__/SettingsView.spec.ts | 2 +- 8 files changed, 8 insertions(+), 46 deletions(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index ab854d24..604ad903 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -265,18 +265,6 @@ func pendingSessionWantsInvitation(payload map[string]any) bool { return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") } -func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool { - if len(payload) == 0 { - return false - } - for _, key := range []string{"access_token", "refresh_token"} { - if value := pendingSessionStringValue(payload, key); value != "" { - return true - } - } - return false -} - func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool { if session == nil { return false diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index f5798486..be4a4cc5 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -301,7 +301,9 @@ func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]st if err != nil { return nil, err } - defer rows.Close() + defer func() { + _ = rows.Close() + }() duplicates := make([]string, 0, 5) for rows.Next() { diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 3d526e7b..c5db3dc4 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -739,10 +739,6 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) } -func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error { - return ensureNormalizedEmailAvailableWithClient(ctx, clientFromContext(ctx, r.client), userID, email) -} - func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error { client = clientFromContext(ctx, client) if client == nil { diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go index 2ef9d761..7da3db9b 100644 --- a/backend/internal/repository/user_repo_email_lookup_unit_test.go +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -209,10 +209,10 @@ func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency( successes := 0 conflicts := 0 for _, err := range errors { - switch { - case err == nil: + switch err { + case nil: successes++ - case err == service.ErrEmailExists: + case service.ErrEmailExists: conflicts++ default: t.Fatalf("unexpected create error: %v", err) diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index d39d2b19..97fd76a0 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -281,15 +281,6 @@ func newLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService return NewPaymentResumeService(signingKey, verifyFallbacks...) } -func psResumeSigningKey(configService *PaymentConfigService) []byte { - signingKey, _ := psResumeSigningKeys(configService) - return signingKey -} - -func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) { - return resolvePaymentResumeSigningKeys(psResumeLegacyVerificationKey(configService)) -} - func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte { if configService == nil { return nil diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go index 5dcdab16..899bd7a0 100644 --- a/backend/internal/service/payment_visible_method_instances.go +++ b/backend/internal/service/payment_visible_method_instances.go @@ -131,21 +131,6 @@ func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProvider return nil } -func buildPaymentProviderConflictError(method string, conflicting *dbent.PaymentProviderInstance) error { - metadata := map[string]string{ - "payment_method": NormalizeVisibleMethod(method), - } - if conflicting != nil { - metadata["conflicting_provider_id"] = fmt.Sprintf("%d", conflicting.ID) - metadata["conflicting_provider_key"] = conflicting.ProviderKey - metadata["conflicting_provider_name"] = conflicting.Name - } - return infraerrors.Conflict( - "PAYMENT_PROVIDER_CONFLICT", - fmt.Sprintf("%s payment already has an enabled provider instance", NormalizeVisibleMethod(method)), - ).WithMetadata(metadata) -} - func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts( ctx context.Context, excludeID int64, diff --git a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts index 118c7615..592ce8a3 100644 --- a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts +++ b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts @@ -21,7 +21,7 @@ describe('AppSidebar custom SVG styles', () => { describe('AppSidebar header styles', () => { it('does not clip the version badge dropdown', () => { - const sidebarHeaderBlockMatch = styleSource.match(/\.sidebar-header\s*\{[\s\S]*?\n \}/) + const sidebarHeaderBlockMatch = styleSource.match(/\.sidebar-header\s*\{[\s\S]*?\n {2}\}/) const sidebarBrandBlockMatch = componentSource.match(/\.sidebar-brand\s*\{[\s\S]*?\n\}/) expect(sidebarHeaderBlockMatch).not.toBeNull() diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts index c294756e..210fd868 100644 --- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts +++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts @@ -1,5 +1,5 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; -import { defineComponent, h, ref } from "vue"; +import { defineComponent, h } from "vue"; import { flushPromises, mount } from "@vue/test-utils"; import SettingsView from "../SettingsView.vue"; From 66680a305615eaae2045ce733f71a49de331cb5f Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 16:44:25 +0800 Subject: [PATCH 25/31] fix(test): update wechat bind start path assertion --- backend/internal/handler/user_handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index e4985a22..fa91bd2a 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -253,7 +253,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey) require.False(t, resp.Data.Identities.WeChat.Bound) require.True(t, resp.Data.Identities.WeChat.CanBind) - require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start") + require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/bind/start") } func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { From a94d89efa753e224d1f260b3289b49ab9df90abd Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 16:51:23 +0800 Subject: [PATCH 26/31] fix(unit): restore secure oidc defaults and wechat alias reuse --- backend/internal/service/admin_service.go | 135 ++++++++++++++++-- backend/internal/service/setting_service.go | 4 +- .../setting_service_oidc_config_test.go | 12 +- 3 files changed, 129 insertions(+), 22 deletions(-) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 110c9008..4ae66613 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 if providerKey == "" || providerSubject == "" { return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") } + canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey) + compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey) var issuer *string if input.Issuer != nil { @@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 } defer func() { _ = tx.Rollback() }() - identity, err := tx.AuthIdentity.Query(). + identityRecords, err := tx.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderKeyIn(compatibleProviderKeys...), authidentity.ProviderSubjectEQ(providerSubject), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - if identity != nil && identity.UserID != userID { + if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) { return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") } + identity := selectOwnedAdminAuthIdentity(identityRecords, userID) if identity == nil { create := tx.AuthIdentity.Create(). SetUserID(userID). SetProviderType(providerType). - SetProviderKey(providerKey). + SetProviderKey(canonicalProviderKey). SetProviderSubject(providerSubject). SetVerifiedAt(verifiedAt) if issuer != nil { @@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) } } else { - update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt) + update := tx.AuthIdentity.UpdateOneID(identity.ID). + SetVerifiedAt(verifiedAt). + SetProviderKey(canonicalProviderKey) if issuer != nil { update = update.SetIssuer(*issuer) } @@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 var channel *dbent.AuthIdentityChannel if channelInput != nil { - channel, err = tx.AuthIdentityChannel.Query(). + channelRecords, err := tx.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ(providerType), - authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ProviderKeyIn(compatibleProviderKeys...), authidentitychannel.ChannelEQ(channelInput.Channel), authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), ). WithIdentity(). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) } - if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) { return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") } + channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID) if channel == nil { create := tx.AuthIdentityChannel.Create(). SetIdentityID(identity.ID). SetProviderType(providerType). - SetProviderKey(providerKey). + SetProviderKey(canonicalProviderKey). SetChannel(channelInput.Channel). SetChannelAppID(channelInput.ChannelAppID). SetChannelSubject(channelInput.ChannelSubject) @@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) } } else { - update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID) + update := tx.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID). + SetProviderKey(canonicalProviderKey) if channelInput.Metadata != nil { update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) } @@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 return buildAdminBoundAuthIdentity(identity, channel), nil } +func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" { + return []string{providerKey} + } + if providerType != "wechat" { + return []string{providerKey} + } + + keys := []string{providerKey} + if !strings.EqualFold(providerKey, "wechat-main") { + keys = append(keys, "wechat-main") + } + if !strings.EqualFold(providerKey, "wechat") { + keys = append(keys, "wechat") + } + return keys +} + +func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + existingKey = strings.TrimSpace(existingKey) + requestedKey = strings.TrimSpace(requestedKey) + if providerType != "wechat" { + if requestedKey != "" { + return requestedKey + } + return existingKey + } + if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") { + return "wechat-main" + } + if requestedKey != "" { + return requestedKey + } + return existingKey +} + +func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerType != "wechat" { + return 0 + } + switch { + case strings.EqualFold(providerKey, "wechat-main"): + return 0 + case strings.EqualFold(providerKey, "wechat"): + return 2 + default: + return 1 + } +} + +func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity { + var selected *dbent.AuthIdentity + for _, record := range records { + if record.UserID != userID { + continue + } + if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record + } + } + return selected +} + +func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool { + for _, record := range records { + if record.UserID != userID { + return true + } + } + return false +} + +func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel { + var selected *dbent.AuthIdentityChannel + for _, record := range records { + if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID { + continue + } + if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record + } + } + return selected +} + +func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool { + for _, record := range records { + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + return true + } + } + return false +} + func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { if input == nil { return nil diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index aac60b08..93b7def1 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -818,14 +818,14 @@ func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool { if base.UsePKCEExplicit { return base.UsePKCE } - return false + return true } func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool { if base.ValidateIDTokenExplicit { return base.ValidateIDToken } - return false + return true } // UpdateSettings 更新系统设置 diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go index 1ece6405..61324204 100644 --- a/backend/internal/service/setting_service_oidc_config_test.go +++ b/backend/internal/service/setting_service_oidc_config_test.go @@ -133,7 +133,7 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue require.True(t, got.OIDCConnectValidateIDToken) } -func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { +func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) { svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ OIDC: config.OIDCConnectConfig{ UsePKCE: true, @@ -145,8 +145,8 @@ func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettin SettingKeyOIDCConnectEnabled: "true", }) - require.False(t, got.OIDCConnectUsePKCE) - require.False(t, got.OIDCConnectValidateIDToken) + require.True(t, got.OIDCConnectUsePKCE) + require.True(t, got.OIDCConnectValidateIDToken) } func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { @@ -216,7 +216,7 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t require.True(t, got.ValidateIDToken) } -func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) { +func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) { cfg := &config.Config{ OIDC: config.OIDCConnectConfig{ Enabled: true, @@ -246,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsM got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) require.NoError(t, err) - require.False(t, got.UsePKCE) - require.False(t, got.ValidateIDToken) + require.True(t, got.UsePKCE) + require.True(t, got.ValidateIDToken) } From 3419cb0112381f70cf5c992aa767b0039a2bb82e Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 17:22:24 +0800 Subject: [PATCH 27/31] fix(admin): preserve legacy oidc security write defaults --- .../internal/handler/admin/setting_handler.go | 7 ++-- backend/internal/service/setting_service.go | 32 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index c6b45ab8..a882d1a1 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -682,8 +682,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } // Generic OIDC 参数验证 - oidcUsePKCE := previousSettings.OIDCConnectUsePKCE - oidcValidateIDToken := previousSettings.OIDCConnectValidateIDToken + oidcUsePKCE, oidcValidateIDToken, err := h.settingService.OIDCSecurityWriteDefaults(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } if req.OIDCConnectEnabled { req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName) req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 93b7def1..f2b644be 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -828,6 +828,16 @@ func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool return true } +func oidcCompatibilityWriteDefault(base config.OIDCConnectConfig, configured bool, raw string, explicit bool, explicitValue bool) bool { + if configured { + return strings.TrimSpace(raw) == "true" + } + if explicit { + return explicitValue + } + return false +} + // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { updates, err := s.buildSystemSettingsUpdates(ctx, settings) @@ -842,6 +852,28 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet return err } +func (s *SettingService) OIDCSecurityWriteDefaults(ctx context.Context) (bool, bool, error) { + rawSettings, err := s.settingRepo.GetMultiple(ctx, []string{ + SettingKeyOIDCConnectUsePKCE, + SettingKeyOIDCConnectValidateIDToken, + }) + if err != nil { + return false, false, fmt.Errorf("get oidc security write defaults: %w", err) + } + + base := config.OIDCConnectConfig{} + if s != nil && s.cfg != nil { + base = s.cfg.OIDC + } + + rawUsePKCE, hasUsePKCE := rawSettings[SettingKeyOIDCConnectUsePKCE] + rawValidateIDToken, hasValidateIDToken := rawSettings[SettingKeyOIDCConnectValidateIDToken] + + return oidcCompatibilityWriteDefault(base, hasUsePKCE, rawUsePKCE, base.UsePKCEExplicit, base.UsePKCE), + oidcCompatibilityWriteDefault(base, hasValidateIDToken, rawValidateIDToken, base.ValidateIDTokenExplicit, base.ValidateIDToken), + nil +} + // UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write. func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error { updates, err := s.buildSystemSettingsUpdates(ctx, settings) From 66b3acc274fa2bcb080735bb60732dd3aa111678 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 17:51:45 +0800 Subject: [PATCH 28/31] fix(lint): remove embedded response selectors in openai images --- backend/internal/service/openai_images.go | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 396c0381..e592cf18 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -1157,9 +1157,9 @@ func uploadOpenAIImageFiles(ctx context.Context, client *req.Client, headers htt if err != nil { return nil, err } - if putResp.Response != nil && putResp.Response.Body != nil { - _, _ = io.Copy(io.Discard, putResp.Response.Body) - _ = putResp.Response.Body.Close() + if putResp.Response != nil && putResp.Body != nil { + _, _ = io.Copy(io.Discard, putResp.Body) + _ = putResp.Body.Close() } if putResp.StatusCode < 200 || putResp.StatusCode >= 300 { return nil, newOpenAIImageStatusError(putResp, "upload image bytes failed") @@ -1294,10 +1294,10 @@ type openAIImageToolMessage struct { } func readOpenAIImageConversationStream(resp *req.Response, startTime time.Time) (string, []openAIImagePointerInfo, OpenAIUsage, *int, error) { - if resp == nil || resp.Response == nil || resp.Response.Body == nil { + if resp == nil || resp.Response == nil || resp.Body == nil { return "", nil, OpenAIUsage{}, nil, fmt.Errorf("empty conversation response") } - reader := bufio.NewReader(resp.Response.Body) + reader := bufio.NewReader(resp.Body) var ( conversationID string firstTokenMs *int @@ -1529,8 +1529,8 @@ func pollOpenAIImageConversation(ctx context.Context, client *req.Client, header lastErr = err } else { if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, readErr := io.ReadAll(resp.Response.Body) - _ = resp.Response.Body.Close() + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() if readErr != nil { lastErr = readErr goto waitNextPoll @@ -1750,14 +1750,14 @@ func newOpenAIImageStatusError(resp *req.Response, fallback string) error { body := []byte(nil) if resp.Response != nil { - headers = resp.Response.Header.Clone() - requestID = strings.TrimSpace(resp.Response.Header.Get("x-request-id")) + headers = resp.Header.Clone() + requestID = strings.TrimSpace(resp.Header.Get("x-request-id")) if resp.Response.Request != nil && resp.Response.Request.URL != nil { requestURL = resp.Response.Request.URL.String() } - if resp.Response.Body != nil { - body, _ = io.ReadAll(io.LimitReader(resp.Response.Body, 2<<20)) - _ = resp.Response.Body.Close() + if resp.Body != nil { + body, _ = io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() } } From 9de7a72cce6404adcd9b5d75e90aa550cc5b0bb2 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 18:01:51 +0800 Subject: [PATCH 29/31] fix(upgrade): close payment and oidc compatibility gaps --- backend/internal/payment/provider/wxpay.go | 30 +---- .../internal/payment/provider/wxpay_test.go | 18 +-- .../internal/repository/migrations_runner.go | 6 +- .../migrations_runner_checksum_test.go | 45 +++++++- .../migrations_runner_extra_test.go | 3 + .../service/payment_order_lifecycle.go | 21 +++- .../service/payment_order_lifecycle_test.go | 107 ++++++++++++++++++ ...24_backfill_legacy_oidc_security_flags.sql | 32 ++++++ ...tity_payment_migrations_regression_test.go | 12 ++ deploy/config.example.yaml | 2 +- 10 files changed, 232 insertions(+), 44 deletions(-) create mode 100644 backend/migrations/124_backfill_legacy_oidc_security_flags.sql diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go index 9927a265..e1b337fc 100644 --- a/backend/internal/payment/provider/wxpay.go +++ b/backend/internal/payment/provider/wxpay.go @@ -200,14 +200,7 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ case wxpayModeJSAPI: return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen) case wxpayModeH5: - resp, err := w.prepayH5(ctx, client, req, notifyURL, totalFen) - if err == nil { - return resp, nil - } - if wxpayShouldFallbackToNative(err) { - return w.prepayNativeFallback(ctx, client, req, notifyURL, totalFen) - } - return nil, err + return w.prepayH5(ctx, client, req, notifyURL, totalFen) case wxpayModeNative: return w.prepayNative(ctx, client, req, notifyURL, totalFen) default: @@ -292,23 +285,6 @@ func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.Create return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil } -func (w *Wxpay) prepayNativeFallback(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) { - resp, err := w.prepayNative(ctx, c, req, notifyURL, totalFen) - if err != nil { - return nil, fmt.Errorf("wxpay native fallback after NO_AUTH: %w", err) - } - nativeURL := strings.TrimSpace(resp.PayURL) - if nativeURL == "" { - nativeURL = strings.TrimSpace(resp.QRCode) - } - if nativeURL == "" { - return resp, nil - } - resp.PayURL = nativeURL - resp.QRCode = nativeURL - return resp, nil -} - func buildWxpayH5Info(config map[string]string) *h5.H5Info { tp := wxpayH5Type info := &h5.H5Info{Type: &tp} @@ -321,10 +297,6 @@ func buildWxpayH5Info(config map[string]string) *h5.H5Info { return info } -func wxpayShouldFallbackToNative(err error) bool { - return err != nil && strings.Contains(err.Error(), wxpayErrNoAuth) -} - func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) { if strings.TrimSpace(req.OpenID) != "" { return wxpayModeJSAPI, nil diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go index a5a406f9..e8ac5e54 100644 --- a/backend/internal/payment/provider/wxpay_test.go +++ b/backend/internal/payment/provider/wxpay_test.go @@ -643,7 +643,7 @@ func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) { } } -func TestCreatePaymentMobileH5FallsBackToNativeOnNoAuth(t *testing.T) { +func TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback(t *testing.T) { origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment origNativePrepay := wxpayNativePrepay origH5Prepay := wxpayH5Prepay @@ -688,8 +688,8 @@ func TestCreatePaymentMobileH5FallsBackToNativeOnNoAuth(t *testing.T) { ClientIP: "203.0.113.10", IsMobile: true, }) - if err != nil { - t.Fatalf("unexpected error: %v", err) + if err == nil { + t.Fatal("expected no-auth error, got nil") } if jsapiCalls != 0 { t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls) @@ -697,13 +697,13 @@ func TestCreatePaymentMobileH5FallsBackToNativeOnNoAuth(t *testing.T) { if h5Calls != 1 { t.Fatalf("h5 prepay calls = %d, want 1", h5Calls) } - if nativeCalls != 1 { - t.Fatalf("native prepay calls = %d, want 1", nativeCalls) + if nativeCalls != 0 { + t.Fatalf("native prepay calls = %d, want 0", nativeCalls) } - if resp.PayURL != "weixin://wxpay/bizpayurl?pr=fallback-native" { - t.Fatalf("pay_url = %q, want native fallback url", resp.PayURL) + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) } - if resp.QRCode != "weixin://wxpay/bizpayurl?pr=fallback-native" { - t.Fatalf("qr_code = %q, want native fallback url", resp.QRCode) + if !strings.Contains(err.Error(), "NO_AUTH") { + t.Fatalf("error = %v, want NO_AUTH", err) } } diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index be4a4cc5..6dbb9fbd 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -66,10 +66,12 @@ type migrationChecksumCompatibilityRule struct { var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{ "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"), "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"), - "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), + "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), + "110_pending_auth_and_provider_default_grants.sql": newMigrationChecksumCompatibilityRule("32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"), + "112_add_payment_order_provider_key_snapshot.sql": newMigrationChecksumCompatibilityRule("b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"), "115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"), "116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"), - "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"), + "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"), "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"), "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"), "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"), diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go index 57647093..1fcb3be1 100644 --- a/backend/internal/repository/migrations_runner_checksum_test.go +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -55,8 +55,17 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { t.Run("109历史checksum可兼容", func(t *testing.T) { ok := isMigrationChecksumCompatible( "109_auth_identity_compat_backfill.sql", - "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", + "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", + ) + require.True(t, ok) + }) + + t.Run("109当前checksum可兼容历史checksum", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "109_auth_identity_compat_backfill.sql", + "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", + "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", ) require.True(t, ok) }) @@ -64,8 +73,26 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { t.Run("109回滚到历史文件后仍兼容已应用的新checksum", func(t *testing.T) { ok := isMigrationChecksumCompatible( "109_auth_identity_compat_backfill.sql", + "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", - "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", + ) + require.True(t, ok) + }) + + t.Run("110历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "110_pending_auth_and_provider_default_grants.sql", + "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925", + "32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", + ) + require.True(t, ok) + }) + + t.Run("112历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "112_add_payment_order_provider_key_snapshot.sql", + "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e", + "b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", ) require.True(t, ok) }) @@ -97,6 +124,20 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { require.True(t, ok) }) + t.Run("118多个历史checksum都可兼容当前版本", func(t *testing.T) { + for _, dbChecksum := range []string{ + "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb", + "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", + } { + ok := isMigrationChecksumCompatible( + "118_wechat_dual_mode_and_auth_source_defaults.sql", + dbChecksum, + "b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", + ) + require.True(t, ok) + } + }) + t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) { for _, dbChecksum := range []string{ "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go index a8bc15bc..5d67665e 100644 --- a/backend/internal/repository/migrations_runner_extra_test.go +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -96,6 +96,9 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) { for _, name := range []string{ + "109_auth_identity_compat_backfill.sql", + "110_pending_auth_and_provider_default_grants.sql", + "112_add_payment_order_provider_key_snapshot.sql", "115_auth_identity_legacy_external_backfill.sql", "116_auth_identity_legacy_external_safety_reports.sql", "118_wechat_dual_mode_and_auth_source_defaults.sql", diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index f14dc55d..b627ced4 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -158,7 +158,11 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s "queryRef": queryRef, }) slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount) - return "" + retriedResp, retryOK := requeryPaidOrderOnce(ctx, prov, queryRef) + if !retryOK { + return "" + } + resp = retriedResp } notificationTradeNo := o.PaymentTradeNo if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) { @@ -184,6 +188,21 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s return "" } +func requeryPaidOrderOnce(ctx context.Context, prov payment.Provider, queryRef string) (*payment.QueryOrderResponse, bool) { + if prov == nil || strings.TrimSpace(queryRef) == "" { + return nil, false + } + resp, err := prov.QueryOrder(ctx, queryRef) + if err != nil { + slog.Warn("query upstream retry failed", "queryRef", queryRef, "error", err) + return nil, false + } + if resp == nil || resp.Status != payment.ProviderStatusPaid || !isValidProviderAmount(resp.Amount) { + return nil, false + } + return resp, true +} + func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string { if order == nil { return "" diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go index cabdb445..8dfd2e7e 100644 --- a/backend/internal/service/payment_order_lifecycle_test.go +++ b/backend/internal/service/payment_order_lifecycle_test.go @@ -21,6 +21,8 @@ import ( type paymentOrderLifecycleQueryProvider struct { lastQueryTradeNo string + queryCalls int + responses []*payment.QueryOrderResponse resp *payment.QueryOrderResponse } @@ -48,6 +50,14 @@ func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, paym func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { p.lastQueryTradeNo = tradeNo + p.queryCalls++ + if len(p.responses) > 0 { + resp := p.responses[0] + if len(p.responses) > 1 { + p.responses = p.responses[1:] + } + return resp, nil + } return p.resp, nil } @@ -234,6 +244,103 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) { require.Equal(t, user.ID, redeemRepo.useCalls[0].userID) } +func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid-retry@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-retry-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-UPSTREAM-RETRY"). + SetOutTradeNo("sub2_checkpaid_retry_zero_amount"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { + require.Equal(t, user.ID, id) + if userRepo.getByIDUser != nil { + userRepo.getByIDUser.Balance += amount + } + return nil + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + responses: []*payment.QueryOrderResponse{ + { + TradeNo: "upstream-trade-zero", + Status: payment.ProviderStatusPaid, + Amount: 0, + }, + { + TradeNo: "upstream-trade-retry", + Status: payment.ProviderStatusPaid, + Amount: 88, + }, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, 2, provider.queryCalls) + require.Equal(t, OrderStatusCompleted, got.Status) + require.Equal(t, "upstream-trade-retry", got.PaymentTradeNo) +} + func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) { ctx := context.Background() client := newPaymentOrderLifecycleTestClient(t) diff --git a/backend/migrations/124_backfill_legacy_oidc_security_flags.sql b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql new file mode 100644 index 00000000..e68bb11a --- /dev/null +++ b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql @@ -0,0 +1,32 @@ +-- Preserve legacy OIDC behavior for upgraded installs that predate the +-- introduction of secure PKCE/id_token defaults. Fresh installs continue to +-- inherit runtime defaults when these rows are absent. + +WITH legacy_oidc_install AS ( + SELECT 1 + FROM settings + WHERE key IN ( + 'oidc_connect_enabled', + 'oidc_connect_client_id', + 'oidc_connect_authorize_url', + 'oidc_connect_token_url', + 'oidc_connect_issuer_url', + 'oidc_connect_userinfo_url', + 'oidc_connect_frontend_redirect_url' + ) + LIMIT 1 +) +INSERT INTO settings (key, value) +SELECT defaults.key, 'false' +FROM legacy_oidc_install +CROSS JOIN ( + VALUES + ('oidc_connect_use_pkce'), + ('oidc_connect_validate_id_token') +) AS defaults(key) +WHERE NOT EXISTS ( + SELECT 1 + FROM settings existing + WHERE existing.key = defaults.key +) +ON CONFLICT (key) DO NOTHING; diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index 6a95d335..798ae0fe 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -115,3 +115,15 @@ func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) require.Contains(t, sql, "value = 'false'") require.Contains(t, sql, "auth_identity_migration_reports") } + +func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) { + content, err := FS.ReadFile("124_backfill_legacy_oidc_security_flags.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "oidc_connect_use_pkce") + require.Contains(t, sql, "oidc_connect_validate_id_token") + require.Contains(t, sql, "ON CONFLICT (key) DO NOTHING") + require.Contains(t, sql, "oidc_connect_enabled") + require.Contains(t, sql, "'false'") +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 358f6a31..dfc363b5 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -841,7 +841,7 @@ linuxdo_connect: frontend_redirect_url: "/auth/linuxdo/callback" token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE - use_pkce: false + use_pkce: true userinfo_email_path: "" userinfo_id_path: "" userinfo_username_path: "" From 1e0d466002456ed8e8adadce8d28225bae0efd97 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 22 Apr 2026 17:49:18 +0800 Subject: [PATCH 30/31] =?UTF-8?q?feat:=20=E8=A1=A5=E5=85=85gpt=E7=94=9F?= =?UTF-8?q?=E5=9B=BE=E6=A8=A1=E5=9E=8B=E6=B5=8B=E8=AF=95=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/pkg/openai/constants.go | 2 + .../internal/service/account_test_service.go | 354 +++++++++++++++++- .../account_test_service_openai_test.go | 4 +- .../service/openai_account_scheduler.go | 10 +- backend/internal/service/openai_images.go | 14 +- .../components/account/AccountTestModal.vue | 89 +++-- .../admin/account/AccountTestModal.vue | 89 +++-- .../__tests__/AccountTestModal.spec.ts | 6 +- frontend/src/composables/useModelWhitelist.ts | 4 +- frontend/src/i18n/locales/en.ts | 16 +- frontend/src/i18n/locales/zh.ts | 16 +- frontend/src/views/admin/GroupsView.vue | 10 +- 12 files changed, 533 insertions(+), 81 deletions(-) diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index f023e32b..60ffefb3 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -20,6 +20,8 @@ var DefaultModels = []Model{ {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, + {ID: "gpt-image-1", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1"}, + {ID: "gpt-image-1.5", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1.5"}, {ID: "gpt-image-2", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 2"}, } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index a5559b7d..396a3973 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/rand" + "encoding/base64" "encoding/hex" "encoding/json" "errors" @@ -52,8 +53,14 @@ type TestEvent struct { const ( defaultGeminiTextTestPrompt = "hi" defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." + defaultOpenAIImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." ) +// isOpenAIImageModel checks if the model is an OpenAI image generation model (e.g. gpt-image-2). +func isOpenAIImageModel(model string) bool { + return strings.HasPrefix(strings.ToLower(model), "gpt-image-") +} + // AccountTestService handles account testing operations type AccountTestService struct { accountRepo AccountRepository @@ -170,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int // Route to platform-specific test method if account.IsOpenAI() { - return s.testOpenAIAccountConnection(c, account, modelID) + return s.testOpenAIAccountConnection(c, account, modelID, prompt) } if account.IsGemini() { @@ -410,7 +417,7 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co } // testOpenAIAccountConnection tests an OpenAI account's connection -func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error { +func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { ctx := c.Request.Context() // Default to openai.DefaultTestModel for OpenAI testing @@ -429,6 +436,18 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } } + // Route to image generation test if an image model is selected + if isOpenAIImageModel(testModelID) { + imagePrompt := strings.TrimSpace(prompt) + if imagePrompt == "" { + imagePrompt = defaultOpenAIImageTestPrompt + } + if account.Type == "apikey" { + return s.testOpenAIImageAPIKey(c, ctx, account, testModelID, imagePrompt) + } + return s.testOpenAIImageOAuth(c, ctx, account, testModelID, imagePrompt) + } + // Determine authentication method and API URL var authToken string var apiURL string @@ -1025,7 +1044,336 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) } } -// sendEvent sends a SSE event to the client +// testOpenAIImageAPIKey tests OpenAI image generation using an API Key account. +func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error { + authToken := account.GetOpenAIApiKey() + if authToken == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) + } + apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations" + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID}) + + payload := map[string]any{ + "model": modelID, + "prompt": prompt, + "n": 1, + "response_format": "b64_json", + } + payloadBytes, _ := json.Marshal(payload) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+authToken) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read response: %s", err.Error())) + } + + if resp.StatusCode != http.StatusOK { + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Parse {"data": [{"b64_json": "...", "revised_prompt": "..."}]} + var result struct { + Data []struct { + B64JSON string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` + } `json:"data"` + } + if err := json.Unmarshal(body, &result); err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error())) + } + + if len(result.Data) == 0 { + return s.sendErrorAndEnd(c, "No images returned from API") + } + + for _, item := range result.Data { + if item.RevisedPrompt != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt}) + } + if item.B64JSON != "" { + s.sendEvent(c, TestEvent{ + Type: "image", + ImageURL: "data:image/png;base64," + item.B64JSON, + MimeType: "image/png", + }) + } + } + + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + +// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API. +func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error { + authToken := account.GetOpenAIAccessToken() + if authToken == "" { + return s.sendErrorAndEnd(c, "No access token available") + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID}) + s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"}) + + // Build headers (replicating buildOpenAIBackendAPIHeaders logic) + headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + client, err := newOpenAIBackendAPIClient(proxyURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error())) + } + + // Bootstrap + if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil { + log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr) + } + + // Fetch chat requirements + s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"}) + chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error())) + } + if chatReqs.Arkose.Required { + return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required") + } + + // Initialize and prepare conversation + s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"}) + parentMessageID := uuid.NewString() + proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent")) + _ = initializeOpenAIImageConversation(ctx, client, headers) + conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error())) + } + + // Build simplified conversation request (no file uploads) + convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID) + convHeaders := cloneHTTPHeader(headers) + convHeaders.Set("Accept", "text/event-stream") + convHeaders.Set("Content-Type", "application/json") + convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token) + if conduitToken != "" { + convHeaders.Set("x-conduit-token", conduitToken) + } + if proofToken != "" { + convHeaders.Set("openai-sentinel-proof-token", proofToken) + } + + s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"}) + + resp, err := client.R(). + SetContext(ctx). + DisableAutoReadResponse(). + SetHeaders(headerToMap(convHeaders)). + SetBodyJsonMarshal(convReq). + Post(openAIChatGPTConversationURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error())) + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + if resp.StatusCode >= 400 { + return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode)) + } + + startTime := time.Now() + conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: %s", err.Error())) + } + + pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) + if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { + s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"}) + polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID) + if pollErr != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error())) + } + pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers) + } + pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) + if len(pointerInfos) == 0 { + return s.sendErrorAndEnd(c, "No images returned from conversation") + } + + s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"}) + + // Download and encode each image + for _, pointer := range pointerInfos { + downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error())) + } + data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error())) + } + b64 := base64.StdEncoding.EncodeToString(data) + mimeType := http.DetectContentType(data) + if pointer.Prompt != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt}) + } + s.sendEvent(c, TestEvent{ + Type: "image", + ImageURL: "data:" + mimeType + ";base64," + b64, + MimeType: mimeType, + }) + } + + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + +// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes. +// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without +// requiring the full gateway service dependency. +func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header { + // Ensure device and session IDs exist + deviceID := account.GetOpenAIDeviceID() + sessionID := account.GetOpenAISessionID() + if deviceID == "" || sessionID == "" { + updates := map[string]any{} + if deviceID == "" { + deviceID = uuid.NewString() + updates["openai_device_id"] = deviceID + } + if sessionID == "" { + sessionID = uuid.NewString() + updates["openai_session_id"] = sessionID + } + if account.Extra == nil { + account.Extra = map[string]any{} + } + for key, value := range updates { + account.Extra[key] = value + } + if repo != nil { + updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + _ = repo.UpdateExtra(updateCtx, account.ID, updates) + } + } + + headers := make(http.Header) + headers.Set("Authorization", "Bearer "+token) + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://chatgpt.com") + headers.Set("Referer", "https://chatgpt.com/") + headers.Set("Sec-Fetch-Dest", "empty") + headers.Set("Sec-Fetch-Mode", "cors") + headers.Set("Sec-Fetch-Site", "same-origin") + headers.Set("User-Agent", openAIImageBackendUserAgent) + if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" { + headers.Set("User-Agent", customUA) + } + if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" { + headers.Set("chatgpt-account-id", chatgptAccountID) + } + if deviceID != "" { + headers.Set("oai-device-id", deviceID) + headers.Set("Cookie", "oai-did="+deviceID) + } + if sessionID != "" { + headers.Set("oai-session-id", sessionID) + } + return headers +} + +// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request. +func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any { + promptText := strings.TrimSpace(prompt) + if promptText == "" { + promptText = "Generate an image." + } + metadata := map[string]any{ + "developer_mode_connector_ids": []any{}, + "selected_github_repos": []any{}, + "selected_all_github_repos": false, + "system_hints": []string{"picture_v2"}, + "serialization_metadata": map[string]any{ + "custom_symbol_offsets": []any{}, + }, + } + message := map[string]any{ + "id": uuid.NewString(), + "author": map[string]any{"role": "user"}, + "content": map[string]any{ + "content_type": "text", + "parts": []any{promptText}, + }, + "metadata": metadata, + "create_time": float64(time.Now().UnixMilli()) / 1000, + } + return map[string]any{ + "action": "next", + "client_prepare_state": "sent", + "parent_message_id": parentMessageID, + "messages": []any{message}, + "model": "auto", + "timezone_offset_min": openAITimezoneOffsetMinutes(), + "timezone": openAITimezoneName(), + "conversation_mode": map[string]any{"kind": "primary_assistant"}, + "system_hints": []string{"picture_v2"}, + "supports_buffering": true, + "supported_encodings": []string{"v1"}, + "client_contextual_info": map[string]any{"app_name": "chatgpt.com"}, + "force_nulligen": false, + "force_paragen": false, + "force_paragen_model_slug": "", + "force_rate_limit": false, + "websocket_request_id": uuid.NewString(), + } +} + func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { eventJSON, _ := json.Marshal(event) if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 82606979..82ff0a8b 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -103,7 +103,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.NoError(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"]) @@ -134,7 +134,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.Error(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index f3533ec4..808f1229 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -917,7 +917,15 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( excludedIDs map[int64]struct{}, requiredCapability OpenAIImagesCapability, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { - return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability) + selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability) + if err == nil && selection != nil && selection.Account != nil { + return selection, decision, nil + } + // 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号) + if requiredCapability == OpenAIImagesCapabilityNative { + return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic) + } + return selection, decision, err } func (s *OpenAIGatewayService) selectAccountWithScheduler( diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 396c0381..48bce22b 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -1157,9 +1157,9 @@ func uploadOpenAIImageFiles(ctx context.Context, client *req.Client, headers htt if err != nil { return nil, err } - if putResp.Response != nil && putResp.Response.Body != nil { - _, _ = io.Copy(io.Discard, putResp.Response.Body) - _ = putResp.Response.Body.Close() + if putResp.Response != nil && putResp.Body != nil { + _, _ = io.Copy(io.Discard, putResp.Body) + _ = putResp.Body.Close() } if putResp.StatusCode < 200 || putResp.StatusCode >= 300 { return nil, newOpenAIImageStatusError(putResp, "upload image bytes failed") @@ -1294,10 +1294,10 @@ type openAIImageToolMessage struct { } func readOpenAIImageConversationStream(resp *req.Response, startTime time.Time) (string, []openAIImagePointerInfo, OpenAIUsage, *int, error) { - if resp == nil || resp.Response == nil || resp.Response.Body == nil { + if resp == nil || resp.Body == nil { return "", nil, OpenAIUsage{}, nil, fmt.Errorf("empty conversation response") } - reader := bufio.NewReader(resp.Response.Body) + reader := bufio.NewReader(resp.Body) var ( conversationID string firstTokenMs *int @@ -1529,8 +1529,8 @@ func pollOpenAIImageConversation(ctx context.Context, client *req.Client, header lastErr = err } else { if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, readErr := io.ReadAll(resp.Response.Body) - _ = resp.Response.Body.Close() + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() if readErr != nil { lastErr = readErr goto waitNextPoll diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 67409a7c..2e3db61b 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -55,12 +55,12 @@ />
-
+