2026-04-20 17:39:57 +08:00
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
2026-04-22 14:56:56 +08:00
"errors"
2026-04-20 17:39:57 +08:00
"fmt"
2026-04-22 14:56:56 +08:00
"hash/fnv"
"sort"
2026-04-20 17:39:57 +08:00
"strings"
2026-04-22 14:56:56 +08:00
"sync"
2026-04-20 17:39:57 +08:00
"time"
2026-04-22 14:56:56 +08:00
"entgo.io/ent/dialect"
2026-04-20 17:39:57 +08:00
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
2026-04-22 02:32:53 +08:00
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
2026-04-20 17:39:57 +08:00
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
2026-04-21 07:48:24 -07:00
entsql "entgo.io/ent/dialect/sql"
2026-04-20 17:39:57 +08:00
)
// Errors describing the state of a pending auth session itself.
var (
	// ErrPendingAuthSessionNotFound is returned when no pending auth session
	// matches the supplied identifier or token.
	ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
	// ErrPendingAuthSessionExpired is the "expired" error used by the
	// browser-session flows (session or completion-code deadline passed).
	ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
	// ErrPendingAuthSessionConsumed is the "already used" error used by the
	// browser-session flows.
	ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
)

// Errors describing a completion code or the browser trying to redeem it.
var (
	// ErrPendingAuthCodeInvalid is returned when a completion code matches no session.
	ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
	// ErrPendingAuthCodeExpired is the "expired" error used when redeeming a completion code.
	ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
	// ErrPendingAuthCodeConsumed is the "already used" error used when redeeming a completion code.
	ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
	// ErrPendingAuthBrowserMismatch is returned when the session is bound to a
	// browser session key and the caller presents a different one.
	ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
)
// defaultPendingAuthTTL is how long a pending auth session stays valid when
// the caller leaves CreatePendingAuthSessionInput.ExpiresAt at its zero value.
const defaultPendingAuthTTL = 15 * time.Minute

// defaultPendingAuthCompletionTTL is the completion-code lifetime used when
// IssuePendingAuthCompletionCodeInput.TTL is zero or negative.
const defaultPendingAuthCompletionTTL = 5 * time.Minute
// PendingAuthIdentityKey identifies the upstream identity a pending auth
// session was started for. CreatePendingSession stores each field
// whitespace-trimmed in the session's provider_type / provider_key /
// provider_subject columns.
type PendingAuthIdentityKey struct {
	ProviderType    string
	ProviderKey     string
	ProviderSubject string
}
// CreatePendingAuthSessionInput carries the data for CreatePendingSession.
// All string fields are stored whitespace-trimmed.
type CreatePendingAuthSessionInput struct {
	// SessionToken is the lookup token for the session; when blank a random
	// 48-character hex token is generated.
	SessionToken string
	Intent       string
	// Identity is the upstream provider identity the session belongs to.
	Identity PendingAuthIdentityKey
	// TargetUserID is optional; it is only stored when non-nil.
	TargetUserID             *int64
	RedirectTo               string
	ResolvedEmail            string
	RegistrationPasswordHash string
	// BrowserSessionKey, when non-blank, binds the session to one browser:
	// consuming it later requires presenting the same key.
	BrowserSessionKey string
	// UpstreamIdentityClaims and LocalFlowState are shallow-copied before storing.
	UpstreamIdentityClaims map[string]any
	LocalFlowState         map[string]any
	// ExpiresAt is converted to UTC; the zero value means
	// "now + defaultPendingAuthTTL".
	ExpiresAt time.Time
}
// IssuePendingAuthCompletionCodeInput carries the data for IssueCompletionCode.
type IssuePendingAuthCompletionCodeInput struct {
	// PendingAuthSessionID is the session the code is issued for.
	PendingAuthSessionID int64
	// BrowserSessionKey, when non-blank (after trimming), replaces the browser
	// key stored on the session; a blank value leaves it unchanged.
	BrowserSessionKey string
	// TTL is the code lifetime; zero or negative means
	// defaultPendingAuthCompletionTTL.
	TTL time.Duration
}
// IssuePendingAuthCompletionCodeResult is returned by IssueCompletionCode.
type IssuePendingAuthCompletionCodeResult struct {
	// Code is the raw one-time code; only its SHA-256 hash is persisted, so
	// this is the only place the plain value is available.
	Code string
	// ExpiresAt is the UTC instant after which the code is rejected.
	ExpiresAt time.Time
}
// PendingIdentityAdoptionDecisionInput carries the data for UpsertAdoptionDecision.
type PendingIdentityAdoptionDecisionInput struct {
	// PendingAuthSessionID keys the decision; at most one decision row exists
	// per pending auth session (it is the upsert conflict column).
	PendingAuthSessionID int64
	// IdentityID is optional; nil or a non-positive value means "no identity".
	IdentityID *int64
	// AdoptDisplayName / AdoptAvatar record which profile fields the user chose
	// to adopt from the upstream identity.
	AdoptDisplayName bool
	AdoptAvatar      bool
}
// AuthPendingIdentityService manages pending auth sessions, their one-time
// completion codes, and identity adoption decisions through an ent client.
// Its methods return an error rather than panicking when the service or its
// client is nil.
type AuthPendingIdentityService struct {
	entClient *dbent.Client
}
// authPendingIdentityScopedKeyLocks is the process-wide registry that
// lockAuthPendingIdentityKeys uses to serialize work on the same keys inside
// this process.
var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()

// authPendingIdentityScopedKeyLockRegistry hands out one mutex per string key.
// Entries are created on first use and removed again once no caller holds or
// waits on them, so the map only contains keys that are currently in use.
type authPendingIdentityScopedKeyLockRegistry struct {
	mu    sync.Mutex // guards locks and the refs counter of every entry
	locks map[string]*authPendingIdentityScopedKeyLockEntry
}

// authPendingIdentityScopedKeyLockEntry is a keyed mutex together with the
// number of callers currently holding or waiting for it.
type authPendingIdentityScopedKeyLockEntry struct {
	mu   sync.Mutex
	refs int // read and written only while the registry's mu is held
}

// newAuthPendingIdentityScopedKeyLockRegistry returns an empty registry.
func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
	registry := new(authPendingIdentityScopedKeyLockRegistry)
	registry.locks = map[string]*authPendingIdentityScopedKeyLockEntry{}
	return registry
}

// lock blocks until the caller holds the mutex of every distinct, non-blank
// (trimmed) key and returns a func that releases them all. Mutexes are taken
// in sorted key order, so callers with overlapping key sets cannot deadlock
// each other. With no usable keys it returns a no-op func.
//
// The returned func must be called exactly once: never calling it keeps the
// keys locked forever, and calling it twice unlocks mutexes that are no
// longer held.
func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
	ordered := normalizeAuthPendingIdentityLockKeys(keys...)
	if len(ordered) == 0 {
		return func() {}
	}

	// Pin every entry (refs++) under the registry lock so none of them can be
	// deleted from the map while we wait on its mutex below.
	held := make([]*authPendingIdentityScopedKeyLockEntry, len(ordered))
	r.mu.Lock()
	for i, key := range ordered {
		e := r.locks[key]
		if e == nil {
			e = new(authPendingIdentityScopedKeyLockEntry)
			r.locks[key] = e
		}
		e.refs++
		held[i] = e
	}
	r.mu.Unlock()

	for _, e := range held {
		e.mu.Lock()
	}

	return func() {
		// Unlock in the reverse of acquisition order.
		for i := range held {
			held[len(held)-1-i].mu.Unlock()
		}

		r.mu.Lock()
		defer r.mu.Unlock()
		for i, key := range ordered {
			e := held[i]
			e.refs--
			if e.refs == 0 {
				delete(r.locks, key)
			}
		}
	}
}

// normalizeAuthPendingIdentityLockKeys trims every key, drops blank keys and
// duplicates, and returns the remainder in ascending order. It returns nil
// when no usable key remains.
func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
	if len(keys) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(keys))
	result := make([]string, 0, len(keys))
	for _, raw := range keys {
		key := strings.TrimSpace(raw)
		if key == "" {
			continue
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		result = append(result, key)
	}
	if len(result) == 0 {
		return nil
	}
	sort.Strings(result)
	return result
}
// authPendingIdentityAdvisoryLockHash maps a lock key to the signed 64-bit id
// passed to PostgreSQL advisory-lock functions. The id is the 64-bit FNV-1a
// hash of the key, reinterpreted as int64 (it may be negative).
func authPendingIdentityAdvisoryLockHash(key string) int64 {
	digest := fnv.New64a()
	// hash.Hash.Write never returns an error.
	_, _ = digest.Write([]byte(key))
	sum := digest.Sum64()
	return int64(sum)
}
// lockAuthPendingIdentityKeys serializes work on the given keys. It always
// takes the process-local keyed mutexes from authPendingIdentityScopedKeyLocks.
// When client talks to PostgreSQL, it additionally takes one
// pg_advisory_xact_lock per normalized key (trimmed, de-duplicated, sorted).
//
// The returned func releases only the process-local mutexes. Advisory xact
// locks are released by PostgreSQL when the surrounding transaction ends.
// NOTE(review): if client is not bound to a transaction, each advisory lock
// only lives for its own statement; callers are expected to pass a
// transactional client.
//
// On error, the process-local mutexes are released before returning, and the
// returned func is nil.
func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
	normalized := normalizeAuthPendingIdentityLockKeys(keys...)
	release := authPendingIdentityScopedKeyLocks.lock(normalized...)
	if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
		return release, nil
	}
	for _, key := range normalized {
		var rows entsql.Rows
		if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
			release()
			return nil, err
		}
		// Depending on the driver, errors raised while the server waits for the
		// lock (deadlock detection, statement timeout, cancellation) may only be
		// reported while the result is read. Drain it and check Err/Close
		// instead of discarding them, otherwise we could proceed without the lock.
		for rows.Next() {
		}
		err := rows.Err()
		if closeErr := rows.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			release()
			return nil, err
		}
	}
	return release, nil
}
// pendingIdentityAdoptionLockKeys returns the lock keys that guard an adoption
// decision: always the pending-session key, plus an identity key when
// identityID is non-nil and positive.
func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
	pendingKey := fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)
	if identityID == nil || *identityID <= 0 {
		return []string{pendingKey}
	}
	identityKey := fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID)
	return []string{pendingKey, identityKey}
}
// NewAuthPendingIdentityService returns a service backed by entClient. A nil
// client is accepted; the service methods then return a configuration error.
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
	svc := new(AuthPendingIdentityService)
	svc.entClient = entClient
	return svc
}
// CreatePendingSession persists a new pending auth session built from input.
//
// String fields are stored whitespace-trimmed. A blank SessionToken is replaced
// by a random 48-character hex token, and a zero ExpiresAt becomes
// now + defaultPendingAuthTTL (UTC). The claim/flow-state maps are
// shallow-copied, and TargetUserID is only set when non-nil.
func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
	if s == nil || s.entClient == nil {
		return nil, fmt.Errorf("pending auth ent client is not configured")
	}

	token := strings.TrimSpace(input.SessionToken)
	if len(token) == 0 {
		generated, err := randomOpaqueToken(24)
		if err != nil {
			return nil, err
		}
		token = generated
	}

	expiry := input.ExpiresAt.UTC()
	if expiry.IsZero() {
		expiry = time.Now().UTC().Add(defaultPendingAuthTTL)
	}

	builder := s.entClient.PendingAuthSession.Create()
	builder.SetSessionToken(token)
	builder.SetIntent(strings.TrimSpace(input.Intent))
	builder.SetProviderType(strings.TrimSpace(input.Identity.ProviderType))
	builder.SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey))
	builder.SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject))
	builder.SetRedirectTo(strings.TrimSpace(input.RedirectTo))
	builder.SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail))
	builder.SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash))
	builder.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
	builder.SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims))
	builder.SetLocalFlowState(copyPendingMap(input.LocalFlowState))
	builder.SetExpiresAt(expiry)
	if uid := input.TargetUserID; uid != nil {
		builder.SetTargetUserID(*uid)
	}
	return builder.Save(ctx)
}
// IssueCompletionCode generates a fresh one-time completion code for an
// existing pending auth session and stores only its SHA-256 hash plus an
// expiry (now + input.TTL, or defaultPendingAuthCompletionTTL when TTL <= 0).
// Any previously issued code for the session is overwritten. A non-blank
// input.BrowserSessionKey replaces the browser key bound to the session.
//
// Returns ErrPendingAuthSessionNotFound when the session does not exist,
// including when it disappears between the lookup and the update.
func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
	if s == nil || s.entClient == nil {
		return nil, fmt.Errorf("pending auth ent client is not configured")
	}
	session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
	if err != nil {
		if dbent.IsNotFound(err) {
			return nil, ErrPendingAuthSessionNotFound
		}
		return nil, err
	}
	code, err := randomOpaqueToken(24)
	if err != nil {
		return nil, err
	}
	ttl := input.TTL
	if ttl <= 0 {
		ttl = defaultPendingAuthCompletionTTL
	}
	expiresAt := time.Now().UTC().Add(ttl)
	update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
		SetCompletionCodeHash(hashPendingAuthCode(code)).
		SetCompletionCodeExpiresAt(expiresAt)
	if browserKey := strings.TrimSpace(input.BrowserSessionKey); browserKey != "" {
		update = update.SetBrowserSessionKey(browserKey)
	}
	if _, err := update.Save(ctx); err != nil {
		// The session can be deleted between Get and this update; report it the
		// same way as a missing session instead of leaking the raw ent error.
		if dbent.IsNotFound(err) {
			return nil, ErrPendingAuthSessionNotFound
		}
		return nil, err
	}
	return &IssuePendingAuthCompletionCodeResult{
		Code:      code,
		ExpiresAt: expiresAt,
	}, nil
}
// ConsumeCompletionCode redeems a one-time completion code previously returned
// by IssueCompletionCode. The raw code is trimmed and looked up by its SHA-256
// hash, and the matching session is then consumed via consumeSession using the
// code-specific expired/consumed errors.
//
// It returns ErrPendingAuthCodeInvalid for a blank or unknown code. Otherwise
// it returns whatever consumeSession reports: ErrPendingAuthCodeExpired,
// ErrPendingAuthCodeConsumed, ErrPendingAuthBrowserMismatch, and so on.
func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
	if s == nil || s.entClient == nil {
		return nil, fmt.Errorf("pending auth ent client is not configured")
	}
	code := strings.TrimSpace(rawCode)
	if code == "" {
		// Issued codes are never blank, so skip hashing and the database round trip.
		return nil, ErrPendingAuthCodeInvalid
	}
	session, err := s.entClient.PendingAuthSession.Query().
		Where(pendingauthsession.CompletionCodeHashEQ(hashPendingAuthCode(code))).
		Only(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			return nil, ErrPendingAuthCodeInvalid
		}
		return nil, err
	}
	return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
}
// ConsumeBrowserSession looks up a pending auth session by its session token
// and consumes it via consumeSession. In that path the errors are
// ErrPendingAuthSessionExpired and ErrPendingAuthSessionConsumed.
// ErrPendingAuthSessionNotFound is returned for a blank or unknown token.
func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
	if s == nil || s.entClient == nil {
		return nil, fmt.Errorf("pending auth ent client is not configured")
	}
	session, lookupErr := s.getBrowserSession(ctx, sessionToken)
	if lookupErr == nil {
		return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
	}
	return nil, lookupErr
}
// GetBrowserSession returns the pending auth session for sessionToken without
// consuming it, after checking that it is unconsumed, unexpired and (if bound)
// presented with the matching browserSessionKey.
func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
	if s == nil || s.entClient == nil {
		return nil, fmt.Errorf("pending auth ent client is not configured")
	}
	session, err := s.getBrowserSession(ctx, sessionToken)
	if err == nil {
		err = validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
	}
	if err != nil {
		return nil, err
	}
	return session, nil
}
// getBrowserSession loads a pending auth session by its (trimmed) session
// token without validating its state. A blank or unknown token yields
// ErrPendingAuthSessionNotFound; other query errors are returned unchanged.
func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
	if s == nil || s.entClient == nil {
		return nil, fmt.Errorf("pending auth ent client is not configured")
	}
	token := strings.TrimSpace(sessionToken)
	if token == "" {
		return nil, ErrPendingAuthSessionNotFound
	}
	found, err := s.entClient.PendingAuthSession.Query().
		Where(pendingauthsession.SessionTokenEQ(token)).
		Only(ctx)
	switch {
	case err == nil:
		return found, nil
	case dbent.IsNotFound(err):
		return nil, ErrPendingAuthSessionNotFound
	default:
		return nil, err
	}
}
func ( s * AuthPendingIdentityService ) consumeSession (
ctx context . Context ,
session * dbent . PendingAuthSession ,
browserSessionKey string ,
expiredErr error ,
consumedErr error ,
) ( * dbent . PendingAuthSession , error ) {
if err := validatePendingSessionState ( session , browserSessionKey , expiredErr , consumedErr ) ; err != nil {
return nil , err
}
2026-04-22 11:29:05 +08:00
sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState ( session . LocalFlowState )
2026-04-20 17:39:57 +08:00
now := time . Now ( ) . UTC ( )
2026-04-22 11:17:38 +08:00
update := s . entClient . PendingAuthSession . UpdateOneID ( session . ID ) .
Where (
pendingauthsession . ConsumedAtIsNil ( ) ,
pendingauthsession . ExpiresAtGTE ( now ) ,
pendingauthsession . Or (
pendingauthsession . CompletionCodeExpiresAtIsNil ( ) ,
pendingauthsession . CompletionCodeExpiresAtGTE ( now ) ,
) ,
) .
2026-04-20 17:39:57 +08:00
SetConsumedAt ( now ) .
2026-04-22 11:29:05 +08:00
SetLocalFlowState ( sanitizedLocalFlowState ) .
2026-04-20 17:39:57 +08:00
SetCompletionCodeHash ( "" ) .
2026-04-22 11:17:38 +08:00
ClearCompletionCodeExpiresAt ( )
if expectedBrowserSessionKey := strings . TrimSpace ( session . BrowserSessionKey ) ; expectedBrowserSessionKey != "" {
update = update . Where ( pendingauthsession . BrowserSessionKeyEQ ( expectedBrowserSessionKey ) )
}
updated , err := update . Save ( ctx )
if err == nil {
return updated , nil
}
if ! dbent . IsNotFound ( err ) {
return nil , err
}
current , currentErr := s . entClient . PendingAuthSession . Get ( ctx , session . ID )
if currentErr != nil {
if dbent . IsNotFound ( currentErr ) {
return nil , ErrPendingAuthSessionNotFound
}
return nil , currentErr
}
if err := validatePendingSessionState ( current , browserSessionKey , expiredErr , consumedErr ) ; err != nil {
2026-04-20 17:39:57 +08:00
return nil , err
}
2026-04-22 11:17:38 +08:00
return nil , consumedErr
2026-04-20 17:39:57 +08:00
}
2026-04-22 11:29:05 +08:00
// sanitizePendingAuthLocalFlowState returns a shallow copy of localFlowState
// (never nil). When its "completion_response" entry is a map[string]any, that
// entry is replaced by a copy without the access_token, refresh_token,
// expires_in and token_type keys. The input maps are not modified.
func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
	sanitized := make(map[string]any, len(localFlowState))
	for key, value := range localFlowState {
		sanitized[key] = value
	}

	completion, isMap := sanitized["completion_response"].(map[string]any)
	if !isMap {
		return sanitized
	}

	scrubbed := make(map[string]any, len(completion))
	for key, value := range completion {
		switch key {
		case "access_token", "refresh_token", "expires_in", "token_type":
			continue // credential material must not survive consumption
		}
		scrubbed[key] = value
	}
	sanitized["completion_response"] = scrubbed
	return sanitized
}
// validatePendingSessionState checks whether session may still be used by the
// browser presenting browserSessionKey.
//
// The checks run in priority order:
//   - a nil session yields ErrPendingAuthSessionNotFound;
//   - an already-consumed session yields consumedErr;
//   - a passed session deadline (a zero ExpiresAt means no deadline) or a
//     passed completion-code deadline yields expiredErr;
//   - a session bound to a browser key that differs from the trimmed
//     browserSessionKey yields ErrPendingAuthBrowserMismatch.
//
// It returns nil when all checks pass.
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
	if session == nil {
		return ErrPendingAuthSessionNotFound
	}
	now := time.Now().UTC()
	sessionExpired := !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt)
	codeExpired := session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt)
	boundKey := strings.TrimSpace(session.BrowserSessionKey)

	switch {
	case session.ConsumedAt != nil:
		return consumedErr
	case sessionExpired, codeExpired:
		return expiredErr
	case boundKey != "" && strings.TrimSpace(browserSessionKey) != boundKey:
		return ErrPendingAuthBrowserMismatch
	default:
		return nil
	}
}
func ( s * AuthPendingIdentityService ) UpsertAdoptionDecision ( ctx context . Context , input PendingIdentityAdoptionDecisionInput ) ( * dbent . IdentityAdoptionDecision , error ) {
if s == nil || s . entClient == nil {
return nil , fmt . Errorf ( "pending auth ent client is not configured" )
}
2026-04-22 14:56:56 +08:00
tx , err := s . entClient . Tx ( ctx )
if err != nil && ! errors . Is ( err , dbent . ErrTxStarted ) {
return nil , err
}
client := s . entClient
txCtx := ctx
if err == nil {
defer func ( ) { _ = tx . Rollback ( ) } ( )
client = tx . Client ( )
txCtx = dbent . NewTxContext ( ctx , tx )
} else if existingTx := dbent . TxFromContext ( ctx ) ; existingTx != nil {
client = existingTx . Client ( )
}
releaseLocks , err := lockAuthPendingIdentityKeys ( txCtx , client , pendingIdentityAdoptionLockKeys ( input . PendingAuthSessionID , input . IdentityID ) ... )
if err != nil {
return nil , err
}
defer releaseLocks ( )
2026-04-21 07:48:24 -07:00
if input . IdentityID != nil && * input . IdentityID > 0 {
2026-04-22 14:56:56 +08:00
if _ , err := client . IdentityAdoptionDecision . Update ( ) .
2026-04-21 07:48:24 -07:00
Where (
identityadoptiondecision . IdentityIDEQ ( * input . IdentityID ) ,
dbpredicate . IdentityAdoptionDecision ( func ( s * entsql . Selector ) {
col := s . C ( identityadoptiondecision . FieldPendingAuthSessionID )
s . Where ( entsql . Or (
entsql . IsNull ( col ) ,
entsql . NEQ ( col , input . PendingAuthSessionID ) ,
) )
} ) ,
) .
ClearIdentityID ( ) .
2026-04-22 14:56:56 +08:00
Save ( txCtx ) ; err != nil {
2026-04-21 07:48:24 -07:00
return nil , err
}
}
2026-04-22 14:56:56 +08:00
create := client . IdentityAdoptionDecision . Create ( ) .
SetPendingAuthSessionID ( input . PendingAuthSessionID ) .
SetAdoptDisplayName ( input . AdoptDisplayName ) .
SetAdoptAvatar ( input . AdoptAvatar ) .
SetDecidedAt ( time . Now ( ) . UTC ( ) )
if input . IdentityID != nil && * input . IdentityID > 0 {
create = create . SetIdentityID ( * input . IdentityID )
}
decisionID , err := create .
OnConflictColumns ( identityadoptiondecision . FieldPendingAuthSessionID ) .
UpdateNewValues ( ) .
ID ( txCtx )
if err != nil {
2026-04-20 17:39:57 +08:00
return nil , err
}
2026-04-22 14:56:56 +08:00
decision , err := client . IdentityAdoptionDecision . Get ( txCtx , decisionID )
if err != nil {
return nil , err
2026-04-20 17:39:57 +08:00
}
2026-04-22 14:56:56 +08:00
if tx != nil {
if err := tx . Commit ( ) ; err != nil {
return nil , err
}
2026-04-20 17:39:57 +08:00
}
2026-04-22 14:56:56 +08:00
return decision , nil
2026-04-20 17:39:57 +08:00
}
// copyPendingMap returns a shallow copy of in. The result is never nil, even
// for a nil or empty input; nested maps and slices are shared, not cloned.
func copyPendingMap(in map[string]any) map[string]any {
	out := make(map[string]any, len(in))
	for key, value := range in {
		out[key] = value
	}
	return out
}
// randomOpaqueToken returns byteLen cryptographically random bytes encoded as
// lowercase hex (2*byteLen characters). A non-positive byteLen means 16 bytes.
func randomOpaqueToken(byteLen int) (string, error) {
	size := byteLen
	if size < 1 {
		size = 16
	}
	raw := make([]byte, size)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}
// hashPendingAuthCode returns the lowercase hex SHA-256 digest of code. Only
// this digest is stored, so a leaked row does not reveal usable codes.
func hashPendingAuthCode(code string) string {
	hasher := sha256.New()
	// hash.Hash.Write never returns an error.
	_, _ = hasher.Write([]byte(code))
	return hex.EncodeToString(hasher.Sum(nil))
}