mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
fix auth completion and payment resume hardening
This commit is contained in:
@@ -1350,10 +1350,24 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if !adoptionDecision.hasDecision() {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
adoptionRequired, _ := payload["adoption_required"].(bool)
|
||||
if adoptionRequired {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
}
|
||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
|
||||
|
||||
decisionReq := adoptionDecision
|
||||
if !decisionReq.hasDecision() {
|
||||
adoptDisplayName := false
|
||||
adoptAvatar := false
|
||||
decisionReq = oauthAdoptionDecisionRequest{
|
||||
AdoptDisplayName: &adoptDisplayName,
|
||||
AdoptAvatar: &adoptAvatar,
|
||||
}
|
||||
}
|
||||
|
||||
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decisionReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -523,6 +523,68 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdopti
|
||||
require.NotNil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
userEntity, err := client.User.Create().
|
||||
SetEmail("login-nodecision@example.com").
|
||||
SetUsername("legacy-name").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("login-nodecision-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("linuxdo").
|
||||
SetProviderKey("linuxdo").
|
||||
SetProviderSubject("login-nodecision-123").
|
||||
SetTargetUserID(userEntity.ID).
|
||||
SetResolvedEmail(userEntity.Email).
|
||||
SetBrowserSessionKey("login-nodecision-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "login-nodecision-user",
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
oauthCompletionResponseKey: map[string]any{
|
||||
"access_token": "access-token",
|
||||
},
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-nodecision-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.ExchangePendingOAuthCompletion(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("linuxdo"),
|
||||
authidentity.ProviderKeyEQ("linuxdo"),
|
||||
authidentity.ProviderSubjectEQ("login-nodecision-123"),
|
||||
).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userEntity.ID, identity.UserID)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
settingValues: map[string]string{
|
||||
|
||||
@@ -190,8 +190,9 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// VerifyOrderPublic verifies payment status without user authentication.
|
||||
// Used by the payment result page when the user's session has expired.
|
||||
// VerifyOrderPublic returns the currently persisted public order state without
|
||||
// triggering any upstream reconciliation. Signed resume-token recovery is the
|
||||
// only public recovery path allowed to query upstream state.
|
||||
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
|
||||
o, err := s.entClient.PaymentOrder.Query().
|
||||
Where(paymentorder.OutTradeNo(outTradeNo)).
|
||||
@@ -199,15 +200,6 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
|
||||
if err != nil {
|
||||
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
|
||||
}
|
||||
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
|
||||
result := s.checkPaid(ctx, o)
|
||||
if result == checkPaidResultAlreadyPaid {
|
||||
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reload order: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -200,3 +200,48 @@ func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T)
|
||||
require.Equal(t, order.ID, got.ID)
|
||||
require.Equal(t, 1, provider.queryCount)
|
||||
}
|
||||
|
||||
func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
user, err := client.User.Create().
|
||||
SetEmail("public-verify@example.com").
|
||||
SetPasswordHash("hash").
|
||||
SetUsername("public-verify-user").
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
order, err := client.PaymentOrder.Create().
|
||||
SetUserID(user.ID).
|
||||
SetUserEmail(user.Email).
|
||||
SetUserName(user.Username).
|
||||
SetAmount(88).
|
||||
SetPayAmount(88).
|
||||
SetFeeRate(0).
|
||||
SetRechargeCode("PUBLIC-VERIFY").
|
||||
SetOutTradeNo("sub2_public_verify_pending").
|
||||
SetPaymentType(payment.TypeAlipay).
|
||||
SetPaymentTradeNo("trade-public-verify").
|
||||
SetOrderType(payment.OrderTypeBalance).
|
||||
SetStatus(OrderStatusPending).
|
||||
SetExpiresAt(time.Now().Add(time.Hour)).
|
||||
SetClientIP("127.0.0.1").
|
||||
SetSrcHost("api.example.com").
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
registry := payment.NewRegistry()
|
||||
provider := &paymentResumeLookupProvider{}
|
||||
registry.Register(provider)
|
||||
|
||||
svc := &PaymentService{
|
||||
entClient: client,
|
||||
registry: registry,
|
||||
providersLoaded: true,
|
||||
}
|
||||
|
||||
got, err := svc.VerifyOrderPublic(ctx, order.OutTradeNo)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, order.ID, got.ID)
|
||||
require.Equal(t, 0, provider.queryCount)
|
||||
}
|
||||
|
||||
@@ -36,6 +36,9 @@ const (
|
||||
|
||||
paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
|
||||
paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
|
||||
|
||||
paymentResumeTokenTTL = 24 * time.Hour
|
||||
wechatPaymentResumeTokenTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
type ResumeTokenClaims struct {
|
||||
@@ -46,6 +49,7 @@ type ResumeTokenClaims struct {
|
||||
PaymentType string `json:"pt,omitempty"`
|
||||
CanonicalReturnURL string `json:"ru,omitempty"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
ExpiresAt int64 `json:"exp,omitempty"`
|
||||
}
|
||||
|
||||
type WeChatPaymentResumeClaims struct {
|
||||
@@ -58,6 +62,7 @@ type WeChatPaymentResumeClaims struct {
|
||||
RedirectTo string `json:"rd,omitempty"`
|
||||
Scope string `json:"scp,omitempty"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
ExpiresAt int64 `json:"exp,omitempty"`
|
||||
}
|
||||
|
||||
type PaymentResumeService struct {
|
||||
@@ -263,6 +268,9 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
|
||||
if claims.IssuedAt == 0 {
|
||||
claims.IssuedAt = time.Now().Unix()
|
||||
}
|
||||
if claims.ExpiresAt == 0 {
|
||||
claims.ExpiresAt = time.Now().Add(paymentResumeTokenTTL).Unix()
|
||||
}
|
||||
return s.createSignedToken(claims)
|
||||
}
|
||||
|
||||
@@ -277,6 +285,9 @@ func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, err
|
||||
if claims.OrderID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
|
||||
}
|
||||
if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_RESUME_TOKEN", "resume token has expired"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
@@ -291,6 +302,9 @@ func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPayme
|
||||
if claims.IssuedAt == 0 {
|
||||
claims.IssuedAt = time.Now().Unix()
|
||||
}
|
||||
if claims.ExpiresAt == 0 {
|
||||
claims.ExpiresAt = time.Now().Add(wechatPaymentResumeTokenTTL).Unix()
|
||||
}
|
||||
if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
|
||||
claims.PaymentType = normalized
|
||||
}
|
||||
@@ -319,6 +333,9 @@ func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeC
|
||||
if claims.OpenID == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
|
||||
}
|
||||
if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token has expired"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
|
||||
claims.PaymentType = normalized
|
||||
}
|
||||
@@ -355,6 +372,16 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
|
||||
return json.Unmarshal(payload, dest)
|
||||
}
|
||||
|
||||
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
|
||||
if expiresAt <= 0 {
|
||||
return nil
|
||||
}
|
||||
if time.Now().Unix() > expiresAt {
|
||||
return infraerrors.BadRequest(code, message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) sign(payload string) string {
|
||||
mac := hmac.New(sha256.New, s.signingKey)
|
||||
_, _ = mac.Write([]byte(payload))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
@@ -175,6 +176,26 @@ func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenRejectsExpiredToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
|
||||
token, err := svc.CreateToken(ResumeTokenClaims{
|
||||
OrderID: 42,
|
||||
UserID: 7,
|
||||
IssuedAt: time.Now().Add(-25 * time.Hour).Unix(),
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateToken returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = svc.ParseToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("ParseToken should reject expired tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -233,6 +254,26 @@ func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMi
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
|
||||
token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
|
||||
OpenID: "openid-123",
|
||||
PaymentType: payment.TypeWxpay,
|
||||
IssuedAt: time.Now().Add(-30 * time.Minute).Unix(),
|
||||
ExpiresAt: time.Now().Add(-1 * time.Minute).Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = svc.ParseWeChatPaymentResumeToken(token)
|
||||
if err == nil {
|
||||
t.Fatal("ParseWeChatPaymentResumeToken should reject expired tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeVisibleMethodSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -219,7 +219,6 @@ onMounted(async () => {
|
||||
const routeOrderId = Number(route.query.order_id) || 0
|
||||
const outTradeNo = String(route.query.out_trade_no || '')
|
||||
let orderId = 0
|
||||
let canUseLegacyPublicVerify = false
|
||||
|
||||
if (resumeToken && typeof window !== 'undefined') {
|
||||
const restored = readPaymentRecoverySnapshot(
|
||||
@@ -264,23 +263,12 @@ onMounted(async () => {
|
||||
const hasLegacyFallbackContext = typeof route.query.trade_status === 'string'
|
||||
&& route.query.trade_status.trim() !== ''
|
||||
if (!order.value && !resumeToken && !orderId && outTradeNo && hasLegacyFallbackContext) {
|
||||
canUseLegacyPublicVerify = true
|
||||
returnInfo.value = {
|
||||
outTradeNo,
|
||||
money: String(route.query.money || ''),
|
||||
type: String(route.query.type || ''),
|
||||
tradeStatus: String(route.query.trade_status || ''),
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await paymentAPI.verifyOrderPublic(outTradeNo)
|
||||
order.value = result.data
|
||||
} catch (_err: unknown) {
|
||||
try {
|
||||
const result = await paymentAPI.verifyOrder(outTradeNo)
|
||||
order.value = result.data
|
||||
} catch (_e: unknown) { /* fall through */ }
|
||||
}
|
||||
}
|
||||
|
||||
const refreshOrder = async (): Promise<PaymentOrder | null> => {
|
||||
@@ -292,20 +280,6 @@ onMounted(async () => {
|
||||
return await paymentStore.pollOrderStatus(orderId)
|
||||
}
|
||||
|
||||
if (canUseLegacyPublicVerify && outTradeNo) {
|
||||
try {
|
||||
const result = await paymentAPI.verifyOrderPublic(outTradeNo)
|
||||
return result.data
|
||||
} catch (_err: unknown) {
|
||||
try {
|
||||
const result = await paymentAPI.verifyOrder(outTradeNo)
|
||||
return result.data
|
||||
} catch (_e: unknown) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
|
||||
@@ -225,16 +225,13 @@ describe('PaymentResultView', () => {
|
||||
expect(verifyOrder).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('keeps legacy out_trade_no verification as a fallback when no order context is available', async () => {
|
||||
it('does not use anonymous out_trade_no verification when no signed resume context is available', async () => {
|
||||
routeState.query = {
|
||||
out_trade_no: 'legacy-123',
|
||||
trade_status: 'TRADE_SUCCESS',
|
||||
}
|
||||
verifyOrderPublic.mockResolvedValue({
|
||||
data: orderFactory('PAID'),
|
||||
})
|
||||
|
||||
const wrapper = mount(PaymentResultView, {
|
||||
mount(PaymentResultView, {
|
||||
global: {
|
||||
stubs: {
|
||||
OrderStatusBadge: true,
|
||||
@@ -244,8 +241,8 @@ describe('PaymentResultView', () => {
|
||||
|
||||
await flushPromises()
|
||||
|
||||
expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123')
|
||||
expect(wrapper.text()).toContain('payment.result.success')
|
||||
expect(verifyOrderPublic).not.toHaveBeenCalled()
|
||||
expect(verifyOrder).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not use public out_trade_no verification for bare order numbers without legacy return markers', async () => {
|
||||
|
||||
Reference in New Issue
Block a user