mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
fix(auth): harden oauth identity upgrade paths
This commit is contained in:
@@ -576,6 +576,258 @@ FROM auth_identity_migration_reports
|
||||
require.Equal(t, beforeCount, afterCount)
|
||||
}
|
||||
|
||||
// TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects
// runs migration 115 over fixtures where two distinct users claim the same
// canonical subject (same linuxdo provider_user_id; same wechat
// provider_union_id). The migration must treat those subjects as ambiguous
// and back off: no auth_identities rows and no auth_identity_channels rows
// may be created for either contested subject.
func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) {
	tx := testTx(t)
	ctx := context.Background()

	// Load the migration under test directly from the migrations directory.
	migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
	migrationSQL, err := os.ReadFile(migrationPath)
	require.NoError(t, err)

	// Ensure the legacy source table exists and start from clean fixtures.
	prepareLegacyExternalIdentitiesTable(t, tx, ctx)
	truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)

	// Two distinct users that will both claim the same linuxdo subject.
	var linuxDoFirstUserID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoFirstUserID))

	var linuxDoSecondUserID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoSecondUserID))

	// Two distinct users that will both claim the same wechat union id.
	var wechatFirstUserID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatFirstUserID))

	var wechatSecondUserID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatSecondUserID))

	// Two legacy linuxdo identities sharing the same provider_user_id but
	// belonging to different users — the ambiguous case for linuxdo.
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
    user_id,
    provider,
    provider_user_id,
    provider_union_id,
    provider_username,
    display_name,
    metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
RETURNING id
`, linuxDoFirstUserID).Scan(new(int64)))

	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
    user_id,
    provider,
    provider_user_id,
    provider_union_id,
    provider_username,
    display_name,
    metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
RETURNING id
`, linuxDoSecondUserID).Scan(new(int64)))

	// Two legacy wechat identities with different openids but the same
	// union id, owned by different users — the ambiguous case for wechat.
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
    user_id,
    provider,
    provider_user_id,
    provider_union_id,
    provider_username,
    display_name,
    metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
RETURNING id
`, wechatFirstUserID).Scan(new(int64)))

	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
    user_id,
    provider,
    provider_user_id,
    provider_union_id,
    provider_username,
    display_name,
    metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
RETURNING id
`, wechatSecondUserID).Scan(new(int64)))

	// Run the backfill migration against the prepared fixtures.
	_, err = tx.ExecContext(ctx, string(migrationSQL))
	require.NoError(t, err)

	// No canonical linuxdo identity may be created for the contested subject.
	var linuxDoIdentityCount int
	require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'linuxdo'
  AND provider_key = 'linuxdo'
  AND provider_subject = 'linuxdo-ambiguous-subject'
`).Scan(&linuxDoIdentityCount))
	require.Zero(t, linuxDoIdentityCount)

	// Likewise for the contested wechat union subject.
	var wechatIdentityCount int
	require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'wechat'
  AND provider_key = 'wechat-main'
  AND provider_subject = 'union-ambiguous-subject'
`).Scan(&wechatIdentityCount))
	require.Zero(t, wechatIdentityCount)

	// And no wechat channel rows may be backfilled for either app id.
	var wechatChannelCount int
	require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_channels
WHERE provider_type = 'wechat'
  AND provider_key = 'wechat-main'
  AND channel = 'oa'
  AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
`).Scan(&wechatChannelCount))
	require.Zero(t, wechatChannelCount)
}
|
||||
|
||||
// TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution
// runs migrations 115 and 116 over conflicting fixtures and verifies that:
// (1) no canonical identity is backfilled for a contested subject, (2) every
// conflicting legacy row receives a 'legacy_external_identity_conflict'
// report, and (3) none of those reports attributes a "winning" identity —
// details ->> 'existing_identity_id' must stay NULL because there is no
// winner when the subject itself is ambiguous.
func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) {
	tx := testTx(t)
	ctx := context.Background()

	// Load both migrations under test from the migrations directory.
	migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
	migration115SQL, err := os.ReadFile(migration115Path)
	require.NoError(t, err)

	migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
	migration116SQL, err := os.ReadFile(migration116Path)
	require.NoError(t, err)

	// Ensure the legacy source table exists and start from clean fixtures.
	prepareLegacyExternalIdentitiesTable(t, tx, ctx)
	truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)

	// Two distinct users that will both claim the same linuxdo subject.
	var linuxDoFirstUserID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoFirstUserID))

	var linuxDoSecondUserID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoSecondUserID))

	// Two distinct users that will both claim the same wechat union id.
	var wechatFirstUserID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatFirstUserID))

	var wechatSecondUserID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatSecondUserID))

	// Conflicting linuxdo legacy rows; keep their IDs to match report keys.
	var linuxDoFirstLegacyID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
    user_id,
    provider,
    provider_user_id,
    provider_union_id,
    provider_username,
    display_name,
    metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
RETURNING id
`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID))

	var linuxDoSecondLegacyID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
    user_id,
    provider,
    provider_user_id,
    provider_union_id,
    provider_username,
    display_name,
    metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
RETURNING id
`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID))

	// Conflicting wechat legacy rows sharing the same union id.
	var wechatFirstLegacyID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
    user_id,
    provider,
    provider_user_id,
    provider_union_id,
    provider_username,
    display_name,
    metadata
) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
RETURNING id
`, wechatFirstUserID).Scan(&wechatFirstLegacyID))

	var wechatSecondLegacyID int64
	require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
    user_id,
    provider,
    provider_user_id,
    provider_union_id,
    provider_username,
    display_name,
    metadata
) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
RETURNING id
`, wechatSecondUserID).Scan(&wechatSecondLegacyID))

	// Apply the backfill, then the safety-report migration.
	_, err = tx.ExecContext(ctx, string(migration115SQL))
	require.NoError(t, err)

	_, err = tx.ExecContext(ctx, string(migration116SQL))
	require.NoError(t, err)

	// No canonical identity may be created for either contested subject.
	var identityCount int
	require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
   OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
`).Scan(&identityCount))
	require.Zero(t, identityCount)

	// All four conflicting legacy rows must be reported.
	var conflictReportCount int
	require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
  AND report_key IN ($1, $2, $3, $4)
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount))
	require.Equal(t, 4, conflictReportCount)

	// ...and none of those reports may point at a "winner" identity.
	var winnerAttributedReportCount int
	require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
  AND report_key IN ($1, $2, $3, $4)
  AND details ->> 'existing_identity_id' IS NOT NULL
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount))
	require.Zero(t, winnerAttributedReportCount)
}
|
||||
|
||||
func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -51,6 +51,8 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
|
||||
// migrationsAdvisoryLockID is the fixed advisory-lock key used to serialize
// concurrent migration runners against the same database (acquired/released
// around applyMigrationsFS; see pg_advisory_unlock usage in tests).
const migrationsAdvisoryLockID int64 = 694208311321144027

// migrationsLockRetryInterval is the delay between attempts to acquire the
// migrations advisory lock.
const migrationsLockRetryInterval = 500 * time.Millisecond

// nonTransactionalMigrationSuffix marks migration files that must be executed
// outside a transaction, statement by statement (e.g. CREATE/DROP INDEX
// CONCURRENTLY).
const nonTransactionalMigrationSuffix = "_notx.sql"

// paymentOrdersOutTradeNoUniqueMigration is the *_notx migration that enforces
// uniqueness of payment_orders.out_trade_no; it gets dedicated pre-flight
// handling (duplicate precheck, invalid-index cleanup).
const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql"

// paymentOrdersOutTradeNoUniqueIndex is the unique index created by that
// migration.
const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique"
|
||||
|
||||
type migrationChecksumCompatibilityRule struct {
|
||||
fileChecksum string
|
||||
@@ -65,9 +67,11 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
|
||||
"054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
|
||||
"061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
|
||||
"109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
|
||||
"115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"),
|
||||
"116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"),
|
||||
"118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"),
|
||||
"119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
|
||||
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
|
||||
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
|
||||
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
|
||||
}
|
||||
|
||||
@@ -195,6 +199,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
}
|
||||
|
||||
if nonTx {
|
||||
if err := prepareNonTransactionalMigration(ctx, db, name); err != nil {
|
||||
return fmt.Errorf("prepare migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
|
||||
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
|
||||
statements := splitSQLStatements(content)
|
||||
@@ -244,6 +252,88 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error {
|
||||
switch name {
|
||||
case paymentOrdersOutTradeNoUniqueMigration:
|
||||
return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error {
|
||||
duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("precheck duplicate out_trade_no: %w", err)
|
||||
}
|
||||
if len(duplicates) > 0 {
|
||||
return fmt.Errorf(
|
||||
"duplicate out_trade_no values block %s; remediate duplicates before retrying: %s",
|
||||
paymentOrdersOutTradeNoUniqueMigration,
|
||||
strings.Join(duplicates, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
|
||||
}
|
||||
if !invalid {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil {
|
||||
return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// findDuplicatePaymentOrderOutTradeNos returns up to five non-empty
// out_trade_no values that occur more than once in payment_orders, formatted
// as "value (count=N)" and ordered by descending duplicate count, then value.
func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) {
	const query = `
SELECT out_trade_no, COUNT(*) AS duplicate_count
FROM payment_orders
WHERE out_trade_no <> ''
GROUP BY out_trade_no
HAVING COUNT(*) > 1
ORDER BY duplicate_count DESC, out_trade_no
LIMIT 5
`
	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	result := make([]string, 0, 5)
	for rows.Next() {
		var (
			tradeNo string
			count   int
		)
		if scanErr := rows.Scan(&tradeNo, &count); scanErr != nil {
			return nil, scanErr
		}
		result = append(result, fmt.Sprintf("%s (count=%d)", tradeNo, count))
	}
	// Surface any iteration error (e.g. connection loss mid-scan).
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}
|
||||
|
||||
// indexIsInvalid reports whether an index with the given name exists in the
// public schema but is flagged invalid in pg_index (the state left behind by
// a failed CREATE INDEX CONCURRENTLY), meaning it must be dropped before the
// index can be rebuilt.
func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) {
	const query = `
SELECT EXISTS (
    SELECT 1
    FROM pg_class idx
    JOIN pg_namespace ns ON ns.oid = idx.relnamespace
    JOIN pg_index i ON i.indexrelid = idx.oid
    WHERE ns.nspname = 'public'
      AND idx.relname = $1
      AND NOT i.indisvalid
)
`
	var invalid bool
	if err := db.QueryRowContext(ctx, query, indexName).Scan(&invalid); err != nil {
		return false, err
	}
	return invalid, nil
}
|
||||
|
||||
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
|
||||
if err != nil {
|
||||
|
||||
@@ -70,6 +70,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"115_auth_identity_legacy_external_backfill.sql",
|
||||
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f",
|
||||
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f",
|
||||
)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"116_auth_identity_legacy_external_safety_reports.sql",
|
||||
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877",
|
||||
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488",
|
||||
)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"119_enforce_payment_orders_out_trade_no_unique.sql",
|
||||
@@ -79,6 +97,21 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) {
|
||||
for _, dbChecksum := range []string{
|
||||
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61",
|
||||
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22",
|
||||
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a",
|
||||
} {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
|
||||
dbChecksum,
|
||||
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074",
|
||||
)
|
||||
require.True(t, ok)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("119未知checksum不兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"119_enforce_payment_orders_out_trade_no_unique.sql",
|
||||
|
||||
@@ -96,6 +96,8 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
|
||||
|
||||
func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
|
||||
for _, name := range []string{
|
||||
"115_auth_identity_legacy_external_backfill.sql",
|
||||
"116_auth_identity_legacy_external_safety_reports.sql",
|
||||
"118_wechat_dual_mode_and_auth_source_defaults.sql",
|
||||
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
|
||||
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
|
||||
|
||||
@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck
// verifies that the 120 *_notx migration aborts before executing any DDL when
// the duplicate out_trade_no precheck finds offending rows, that the error
// names both the problem and an offending value, and that the migrations
// advisory lock is still released on the failure path.
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()

	prepareMigrationsBootstrapExpectations(mock)
	// The migration has not been applied yet (no checksum row).
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
		WillReturnError(sql.ErrNoRows)
	// The precheck reports one duplicated out_trade_no, which must abort the
	// run — note no DDL exec expectations follow.
	mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
		WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2))
	// Even on failure the advisory lock must be released.
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))

	fsys := fstest.MapFS{
		"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
			Data: []byte(`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';

DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`),
		},
	}

	err = applyMigrationsFS(context.Background(), db, fsys)
	require.Error(t, err)
	require.Contains(t, err.Error(), "duplicate out_trade_no")
	require.Contains(t, err.Error(), "dup-out-trade-no")
	require.NoError(t, mock.ExpectationsWereMet())
}
|
||||
|
||||
// TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry
// verifies the retry path of the 120 *_notx migration: when the duplicate
// precheck is clean but a previous run left an invalid index behind, the
// runner drops that index first, then executes the migration's statements one
// by one, records the migration in schema_migrations, and releases the
// advisory lock. sqlmock expectations are ordered, so this also pins the
// exact sequence of operations.
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()

	prepareMigrationsBootstrapExpectations(mock)
	// The migration has not been applied yet (no checksum row).
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
		WillReturnError(sql.ErrNoRows)
	// Precheck finds no duplicates.
	mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
		WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}))
	// The invalid-index probe reports a leftover half-built index...
	mock.ExpectQuery("SELECT EXISTS \\(").
		WithArgs("paymentorder_out_trade_no_unique").
		WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
	// ...which must be dropped before the migration statements run.
	mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique").
		WillReturnResult(sqlmock.NewResult(0, 0))
	// Then the migration's two statements, executed individually.
	mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no").
		WillReturnResult(sqlmock.NewResult(0, 0))
	// Success is recorded, and the advisory lock released.
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))

	fsys := fstest.MapFS{
		"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
			Data: []byte(`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';

DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`),
		},
	}

	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
|
||||
|
||||
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -93,6 +93,19 @@ func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T)
|
||||
tx := testTx(t)
|
||||
|
||||
requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
|
||||
requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
|
||||
requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
|
||||
requireConstraintDefinitionContains(
|
||||
t,
|
||||
tx,
|
||||
"users",
|
||||
"users_signup_source_check",
|
||||
"signup_source",
|
||||
"'email'",
|
||||
"'linuxdo'",
|
||||
"'wechat'",
|
||||
"'oidc'",
|
||||
)
|
||||
|
||||
requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
|
||||
requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
|
||||
@@ -195,6 +208,45 @@ LIMIT 1
|
||||
require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
|
||||
}
|
||||
|
||||
func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
|
||||
t.Helper()
|
||||
|
||||
var def string
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT pg_get_constraintdef(c.oid)
|
||||
FROM pg_constraint c
|
||||
JOIN pg_class tbl ON tbl.oid = c.conrelid
|
||||
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
|
||||
WHERE ns.nspname = 'public'
|
||||
AND tbl.relname = $1
|
||||
AND c.conname = $2
|
||||
`, table, constraint).Scan(&def)
|
||||
require.NoError(t, err, "query constraint definition for %s.%s", table, constraint)
|
||||
|
||||
for _, fragment := range fragments {
|
||||
require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment)
|
||||
}
|
||||
}
|
||||
|
||||
func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
|
||||
t.Helper()
|
||||
|
||||
var columnDefault sql.NullString
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = $1
|
||||
AND column_name = $2
|
||||
`, table, column).Scan(&columnDefault)
|
||||
require.NoError(t, err, "query column_default for %s.%s", table, column)
|
||||
require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column)
|
||||
|
||||
for _, fragment := range fragments {
|
||||
require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment)
|
||||
}
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -4,11 +4,15 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
@@ -120,6 +124,113 @@ type sqlQueryExecutor interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// repositoryScopedKeyLocks is the process-wide registry of in-memory,
// per-key locks used to serialize repository operations on the same logical
// resource within this process.
var repositoryScopedKeyLocks = newScopedKeyLockRegistry()

// scopedKeyLockRegistry hands out per-key mutexes on demand and reclaims an
// entry once no caller holds or waits on its key (refs drops to zero).
type scopedKeyLockRegistry struct {
	mu    sync.Mutex                     // guards locks and every entry's refs
	locks map[string]*scopedKeyLockEntry // live entries keyed by normalized key
}

// scopedKeyLockEntry is a single key's lock plus the count of callers
// currently holding or waiting on it. refs is mutated only while the owning
// registry's mu is held.
type scopedKeyLockEntry struct {
	mu   sync.Mutex
	refs int
}

// newScopedKeyLockRegistry returns an empty, ready-to-use registry.
func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
	return &scopedKeyLockRegistry{
		locks: make(map[string]*scopedKeyLockEntry),
	}
}
||||
|
||||
// lock acquires an in-process mutex for every normalized key and returns a
// release function that must be called exactly once. Keys are deduplicated,
// trimmed, and sorted by normalizeLockKeys, so concurrent callers acquire
// overlapping key sets in a single global order — which prevents lock-order
// deadlocks between them.
func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
	normalized := normalizeLockKeys(keys...)
	if len(normalized) == 0 {
		// Nothing to lock: hand back a no-op release.
		return func() {}
	}

	// Pin (or create) one entry per key and bump its refcount under the
	// registry mutex, so entries cannot be reclaimed while we wait on them.
	entries := make([]*scopedKeyLockEntry, 0, len(normalized))
	r.mu.Lock()
	for _, key := range normalized {
		entry := r.locks[key]
		if entry == nil {
			entry = &scopedKeyLockEntry{}
			r.locks[key] = entry
		}
		entry.refs++
		entries = append(entries, entry)
	}
	r.mu.Unlock()

	// Acquire the per-key mutexes outside the registry lock, in sorted key
	// order (entries parallels normalized).
	for _, entry := range entries {
		entry.mu.Lock()
	}

	return func() {
		// Release the per-key mutexes in reverse acquisition order.
		for i := len(entries) - 1; i >= 0; i-- {
			entries[i].mu.Unlock()
		}

		// Drop our references and reclaim entries no one else is using.
		r.mu.Lock()
		defer r.mu.Unlock()
		for idx, key := range normalized {
			entry := entries[idx]
			entry.refs--
			if entry.refs == 0 {
				delete(r.locks, key)
			}
		}
	}
}
|
||||
|
||||
// normalizeLockKeys trims whitespace from each key, drops empties, removes
// duplicates, and returns the survivors sorted ascending. It returns nil when
// no usable key remains, so callers can treat "nothing to lock" uniformly.
func normalizeLockKeys(keys ...string) []string {
	seen := make(map[string]struct{}, len(keys))
	var result []string
	for _, raw := range keys {
		key := strings.TrimSpace(raw)
		if key == "" {
			continue
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		result = append(result, key)
	}
	if len(result) == 0 {
		return nil
	}
	sort.Strings(result)
	return result
}
|
||||
|
||||
// advisoryLockHash maps a string key to a signed 64-bit value via FNV-1a,
// suitable as an argument to pg_advisory_xact_lock. Writing to an fnv hasher
// never returns an error, so the Write result is deliberately discarded.
func advisoryLockHash(key string) int64 {
	digest := fnv.New64a()
	_, _ = digest.Write([]byte(key))
	return int64(digest.Sum64())
}
|
||||
|
||||
// lockRepositoryScopedKeys takes the in-process scoped-key locks for keys
// and, when the backing database is Postgres, additionally takes a
// transaction-scoped advisory lock (pg_advisory_xact_lock) per normalized
// key so that other processes sharing the database serialize on the same
// keys. The returned release function frees only the in-process locks; the
// advisory locks are released by Postgres when the surrounding transaction
// ends.
func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
	release := repositoryScopedKeyLocks.lock(keys...)
	normalized := normalizeLockKeys(keys...)
	// Skip DB-level locking when there is nothing to lock, no client or
	// executor was supplied, or the dialect is not Postgres.
	if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
		return release, nil
	}

	for _, key := range normalized {
		rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
		if err != nil {
			// Undo the in-process locks; any advisory locks already taken
			// will be released with the transaction.
			release()
			return nil, err
		}
		// The result row carries no useful data; close to free the cursor.
		_ = rows.Close()
	}
	return release, nil
}
|
||||
|
||||
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
|
||||
if dbent.TxFromContext(ctx) != nil {
|
||||
return fn(ctx)
|
||||
@@ -329,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
|
||||
update := client.AuthIdentity.UpdateOneID(identity.ID)
|
||||
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
|
||||
update = update.SetProviderKey(targetProviderKey)
|
||||
}
|
||||
if input.Metadata != nil {
|
||||
update = update.SetMetadata(copyMetadata(input.Metadata))
|
||||
}
|
||||
@@ -378,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
|
||||
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
|
||||
SetIdentityID(identity.ID)
|
||||
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
|
||||
update = update.SetProviderKey(targetProviderKey)
|
||||
}
|
||||
if input.ChannelMetadata != nil {
|
||||
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
|
||||
}
|
||||
@@ -418,13 +537,52 @@ func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
existingKey = strings.TrimSpace(existingKey)
|
||||
requestedKey = strings.TrimSpace(requestedKey)
|
||||
if providerType != "wechat" {
|
||||
if requestedKey != "" {
|
||||
return requestedKey
|
||||
}
|
||||
return existingKey
|
||||
}
|
||||
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
|
||||
return "wechat-main"
|
||||
}
|
||||
if requestedKey != "" {
|
||||
return requestedKey
|
||||
}
|
||||
return existingKey
|
||||
}
|
||||
|
||||
func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
providerKey = strings.TrimSpace(providerKey)
|
||||
if providerType != "wechat" {
|
||||
return 0
|
||||
}
|
||||
switch {
|
||||
case strings.EqualFold(providerKey, "wechat-main"):
|
||||
return 0
|
||||
case strings.EqualFold(providerKey, "wechat"):
|
||||
return 2
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
|
||||
var selected *dbent.AuthIdentity
|
||||
for _, record := range records {
|
||||
if record.UserID == userID {
|
||||
return record
|
||||
if record.UserID != userID {
|
||||
continue
|
||||
}
|
||||
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||
selected = record
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return selected
|
||||
}
|
||||
|
||||
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
|
||||
@@ -437,12 +595,16 @@ func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64)
|
||||
}
|
||||
|
||||
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
|
||||
var selected *dbent.AuthIdentityChannel
|
||||
for _, record := range records {
|
||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID {
|
||||
return record
|
||||
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
|
||||
continue
|
||||
}
|
||||
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||
selected = record
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return selected
|
||||
}
|
||||
|
||||
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
|
||||
@@ -479,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
|
||||
}
|
||||
|
||||
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
if _, err := client.IdentityAdoptionDecision.Update().
|
||||
Where(
|
||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.NEQ(col, input.PendingAuthSessionID),
|
||||
))
|
||||
}),
|
||||
).
|
||||
ClearIdentityID().
|
||||
Save(ctx); err != nil {
|
||||
return nil, err
|
||||
var result *dbent.IdentityAdoptionDecision
|
||||
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
|
||||
client := clientFromContext(txCtx, r.client)
|
||||
releaseLocks, err := lockRepositoryScopedKeys(
|
||||
txCtx,
|
||||
client,
|
||||
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||
identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer releaseLocks()
|
||||
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
if _, err := client.IdentityAdoptionDecision.Update().
|
||||
Where(
|
||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.NEQ(col, input.PendingAuthSessionID),
|
||||
))
|
||||
}),
|
||||
).
|
||||
ClearIdentityID().
|
||||
Save(txCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
current, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if current == nil {
|
||||
create := client.IdentityAdoptionDecision.Create().
|
||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar).
|
||||
SetDecidedAt(now)
|
||||
if input.IdentityID != nil {
|
||||
SetDecidedAt(time.Now().UTC())
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
create = create.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
return create.Save(ctx)
|
||||
}
|
||||
|
||||
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar)
|
||||
if input.IdentityID != nil {
|
||||
update = update.SetIdentityID(*input.IdentityID)
|
||||
decisionID, err := create.
|
||||
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
|
||||
UpdateNewValues().
|
||||
ID(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return update.Save(ctx)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
|
||||
keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
|
||||
if identityID != nil && *identityID > 0 {
|
||||
keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
|
||||
|
||||
@@ -0,0 +1,212 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) {
|
||||
repo, client := newUserEntRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &service.User{
|
||||
Email: "wechat-legacy@example.com",
|
||||
Username: "wechat-legacy",
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, user))
|
||||
|
||||
legacyIdentity, err := client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat").
|
||||
SetProviderSubject("union-legacy-123").
|
||||
SetMetadata(map[string]any{"source": "legacy-alias"}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
legacyChannel, err := client.AuthIdentityChannel.Create().
|
||||
SetIdentityID(legacyIdentity.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat").
|
||||
SetChannel("oa").
|
||||
SetChannelAppID("wx-app-legacy").
|
||||
SetChannelSubject("openid-legacy-123").
|
||||
SetMetadata(map[string]any{"scene": "legacy-alias"}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{
|
||||
UserID: user.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-main",
|
||||
ProviderSubject: "union-legacy-123",
|
||||
},
|
||||
Channel: &AuthIdentityChannelKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-main",
|
||||
Channel: "oa",
|
||||
ChannelAppID: "wx-app-legacy",
|
||||
ChannelSubject: "openid-legacy-123",
|
||||
},
|
||||
Metadata: map[string]any{"source": "canonical-bind"},
|
||||
ChannelMetadata: map[string]any{"scene": "canonical-bind"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bound)
|
||||
require.NotNil(t, bound.Identity)
|
||||
require.NotNil(t, bound.Channel)
|
||||
require.Equal(t, legacyIdentity.ID, bound.Identity.ID)
|
||||
require.Equal(t, legacyChannel.ID, bound.Channel.ID)
|
||||
require.Equal(t, "wechat-main", bound.Identity.ProviderKey)
|
||||
require.Equal(t, "wechat-main", bound.Channel.ProviderKey)
|
||||
|
||||
reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey)
|
||||
require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"])
|
||||
|
||||
reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "wechat-main", reloadedChannel.ProviderKey)
|
||||
require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"])
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.UserIDEQ(user.ID),
|
||||
authidentity.ProviderTypeEQ("wechat"),
|
||||
authidentity.ProviderSubjectEQ("union-legacy-123"),
|
||||
).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, identityCount)
|
||||
|
||||
channelCount, err := client.AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ("wechat"),
|
||||
authidentitychannel.ChannelEQ("oa"),
|
||||
authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
|
||||
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
|
||||
).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, channelCount)
|
||||
}
|
||||
|
||||
func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) {
|
||||
repo, client := newUserEntRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &service.User{
|
||||
Email: "repo-adoption@example.com",
|
||||
Username: "repo-adoption",
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, user))
|
||||
|
||||
identity, err := client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat-main").
|
||||
SetProviderSubject("union-repo-adoption").
|
||||
SetMetadata(map[string]any{}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("pending-repo-adoption").
|
||||
SetIntent("bind_current_user").
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat-main").
|
||||
SetProviderSubject("union-repo-adoption").
|
||||
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
|
||||
SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}).
|
||||
SetLocalFlowState(map[string]any{"step": "pending"}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
firstCreateStarted := make(chan struct{})
|
||||
releaseFirstCreate := make(chan struct{})
|
||||
var firstCreate sync.Once
|
||||
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
|
||||
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
|
||||
blocked := false
|
||||
if m.Op().Is(dbent.OpCreate) {
|
||||
firstCreate.Do(func() {
|
||||
blocked = true
|
||||
close(firstCreateStarted)
|
||||
})
|
||||
}
|
||||
if blocked {
|
||||
<-releaseFirstCreate
|
||||
}
|
||||
return next.Mutate(ctx, m)
|
||||
})
|
||||
})
|
||||
|
||||
type adoptionResult struct {
|
||||
decision *dbent.IdentityAdoptionDecision
|
||||
err error
|
||||
}
|
||||
|
||||
input := IdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
IdentityID: &identity.ID,
|
||||
AdoptDisplayName: true,
|
||||
AdoptAvatar: true,
|
||||
}
|
||||
|
||||
results := make(chan adoptionResult, 2)
|
||||
go func() {
|
||||
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
|
||||
results <- adoptionResult{decision: decision, err: err}
|
||||
}()
|
||||
|
||||
<-firstCreateStarted
|
||||
|
||||
go func() {
|
||||
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
|
||||
results <- adoptionResult{decision: decision, err: err}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(releaseFirstCreate)
|
||||
|
||||
first := <-results
|
||||
second := <-results
|
||||
|
||||
require.NoError(t, first.err)
|
||||
require.NoError(t, second.err)
|
||||
require.NotNil(t, first.decision)
|
||||
require.NotNil(t, second.decision)
|
||||
require.Equal(t, first.decision.ID, second.decision.ID)
|
||||
|
||||
count, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
loaded, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loaded.IdentityID)
|
||||
require.Equal(t, identity.ID, *loaded.IdentityID)
|
||||
require.True(t, loaded.AdoptDisplayName)
|
||||
require.True(t, loaded.AdoptAvatar)
|
||||
}
|
||||
@@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||
@@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
txCtx := ctx
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
txCtx = dbent.NewTxContext(ctx, tx)
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
@@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
}
|
||||
}
|
||||
|
||||
releaseEmailLock, err := lockRepositoryScopedKeys(
|
||||
txCtx,
|
||||
txClient,
|
||||
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||
normalizedEmailUniquenessLockKey(userIn.Email),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer releaseEmailLock()
|
||||
|
||||
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
created, err := txClient.User.Create().
|
||||
SetEmail(userIn.Email).
|
||||
SetUsername(userIn.Username).
|
||||
@@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||||
Save(ctx)
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
|
||||
if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
@@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
txCtx := ctx
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
txCtx = dbent.NewTxContext(ctx, tx)
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
@@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
txClient = r.client
|
||||
}
|
||||
}
|
||||
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
|
||||
|
||||
releaseEmailLock, err := lockRepositoryScopedKeys(
|
||||
txCtx,
|
||||
txClient,
|
||||
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||
normalizedEmailUniquenessLockKey(userIn.Email),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer releaseEmailLock()
|
||||
|
||||
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
@@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
if userIn.BalanceNotifyThreshold == nil {
|
||||
updateOp = updateOp.ClearBalanceNotifyThreshold()
|
||||
}
|
||||
updated, err := updateOp.Save(ctx)
|
||||
updated, err := updateOp.Save(txCtx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
|
||||
if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
||||
}
|
||||
|
||||
func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error {
|
||||
matches, err := r.client.User.Query().
|
||||
return ensureNormalizedEmailAvailableWithClient(ctx, clientFromContext(ctx, r.client), userID, email)
|
||||
}
|
||||
|
||||
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
|
||||
client = clientFromContext(ctx, client)
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
matches, err := client.User.Query().
|
||||
Where(userEmailLookupPredicate(email)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
@@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use
|
||||
}
|
||||
|
||||
func userEmailLookupPredicate(email string) predicate.User {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
normalized := normalizeEmailLookupValue(email)
|
||||
if normalized == "" {
|
||||
return dbuser.EmailEQ(email)
|
||||
}
|
||||
@@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User {
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeEmailLookupValue(email string) string {
|
||||
return strings.ToLower(strings.TrimSpace(email))
|
||||
}
|
||||
|
||||
func normalizedEmailUniquenessLockKey(email string) string {
|
||||
normalized := normalizeEmailLookupValue(email)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
return "users:normalized-email:" + normalized
|
||||
}
|
||||
|
||||
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
err := client.UserAllowedGroup.Create().
|
||||
@@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
}
|
||||
|
||||
func userSignupSourceOrDefault(signupSource string) string {
|
||||
signupSource = strings.TrimSpace(signupSource)
|
||||
if signupSource == "" {
|
||||
switch strings.TrimSpace(strings.ToLower(signupSource)) {
|
||||
case "", "email":
|
||||
return "email"
|
||||
case "linuxdo", "wechat", "oidc":
|
||||
return strings.TrimSpace(strings.ToLower(signupSource))
|
||||
default:
|
||||
return "email"
|
||||
}
|
||||
return signupSource
|
||||
}
|
||||
|
||||
// marshalExtraEmails serializes notify email entries to JSON for storage.
|
||||
|
||||
@@ -3,7 +3,10 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
@@ -18,9 +21,10 @@ import (
|
||||
func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared")
|
||||
db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
db.SetMaxOpenConns(10)
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
@@ -144,3 +148,80 @@ func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "normalized email lookup matched multiple users")
|
||||
}
|
||||
|
||||
func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) {
|
||||
repo, client := newUserEntRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
firstCreateStarted := make(chan struct{})
|
||||
releaseFirstCreate := make(chan struct{})
|
||||
var firstCreate sync.Once
|
||||
client.User.Use(func(next dbent.Mutator) dbent.Mutator {
|
||||
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
|
||||
blocked := false
|
||||
if m.Op().Is(dbent.OpCreate) {
|
||||
firstCreate.Do(func() {
|
||||
blocked = true
|
||||
close(firstCreateStarted)
|
||||
})
|
||||
}
|
||||
if blocked {
|
||||
<-releaseFirstCreate
|
||||
}
|
||||
return next.Mutate(ctx, m)
|
||||
})
|
||||
})
|
||||
|
||||
type createResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
results := make(chan createResult, 2)
|
||||
go func() {
|
||||
results <- createResult{err: repo.Create(ctx, &service.User{
|
||||
Email: " Race@Example.com ",
|
||||
Username: "race-user-1",
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})}
|
||||
}()
|
||||
|
||||
<-firstCreateStarted
|
||||
|
||||
go func() {
|
||||
results <- createResult{err: repo.Create(ctx, &service.User{
|
||||
Email: "race@example.com",
|
||||
Username: "race-user-2",
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(releaseFirstCreate)
|
||||
|
||||
first := <-results
|
||||
second := <-results
|
||||
|
||||
errors := []error{first.err, second.err}
|
||||
successes := 0
|
||||
conflicts := 0
|
||||
for _, err := range errors {
|
||||
switch {
|
||||
case err == nil:
|
||||
successes++
|
||||
case err == service.ErrEmailExists:
|
||||
conflicts++
|
||||
default:
|
||||
t.Fatalf("unexpected create error: %v", err)
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, successes)
|
||||
require.Equal(t, 1, conflicts)
|
||||
|
||||
count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user