package service

import (
	"context"
	"fmt"
	"log/slog"
	"strconv"
	"time"
)

// OAuthRefreshExecutor is the per-platform OAuth refresh executor.
// It is a superset of TokenRefresher: it adds CacheKey, which names the
// distributed lock that guards a refresh.
type OAuthRefreshExecutor interface {
	TokenRefresher

	// CacheKey returns the cache key used for the distributed lock
	// (intended to match the key used by TokenProvider).
	CacheKey(account *Account) string
}

const (
	// refreshLockTTL is the TTL requested when acquiring the refresh lock.
	refreshLockTTL = 30 * time.Second

	// refreshLockReleaseTimeout bounds the deferred lock release. The release
	// runs on a context detached from the caller's cancellation, so it needs
	// its own deadline to avoid blocking forever on an unresponsive cache.
	refreshLockReleaseTimeout = 5 * time.Second
)

// OAuthRefreshResult is the unified outcome of a refresh attempt.
type OAuthRefreshResult struct {
	Refreshed      bool           // a refresh was actually executed
	NewCredentials map[string]any // credentials after the refresh (nil means not refreshed)
	Account        *Account       // the latest account, re-read from the DB when possible
	LockHeld       bool           // the lock is held by another worker (no refresh was executed)
}

// OAuthRefreshAPI is the unified entry point for OAuth token refreshes.
// It wraps the platform-independent parts: distributed locking, re-reading
// the account from the DB, and the "already refreshed" double check.
type OAuthRefreshAPI struct {
	accountRepo AccountRepository
	tokenCache  GeminiTokenCache // optional; nil means refresh without a lock
}

// NewOAuthRefreshAPI creates the unified refresh API.
// tokenCache may be nil, in which case refreshes run without a lock.
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
	return &OAuthRefreshAPI{
		accountRepo: accountRepo,
		tokenCache:  tokenCache,
	}
}

// RefreshIfNeeded refreshes the account's OAuth token on demand under the
// protection of a distributed lock.
//
// Flow:
//  1. Acquire the distributed lock.
//  2. Re-read the latest account from the DB (avoids using a stale refresh_token).
//  3. Double-check whether a refresh is still needed.
//  4. Call executor.Refresh() for the platform-specific refresh logic.
//  5. Set _token_version and update the DB.
//  6. Release the lock.
//
// Return values:
//   - {LockHeld: true} when another worker holds the lock; nothing is refreshed.
//   - {Account: fresh} when the double check says no refresh is needed.
//   - {Refreshed: true, ...} after executor.Refresh succeeds. If the executor
//     returns nil credentials without an error, the DB is left untouched but
//     Refreshed is still reported as true.
//   - a non-nil error when executor.Refresh or the DB update fails.
//
// account and executor must be non-nil.
func (api *OAuthRefreshAPI) RefreshIfNeeded(
	ctx context.Context,
	account *Account,
	executor OAuthRefreshExecutor,
	refreshWindow time.Duration,
) (*OAuthRefreshResult, error) {
	cacheKey := executor.CacheKey(account)

	// 1. Acquire the distributed lock (skipped when no cache is configured).
	if api.tokenCache != nil {
		acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
		switch {
		case lockErr != nil:
			// Cache error: degrade to a lock-free refresh rather than failing.
			slog.Warn("oauth_refresh_lock_failed_degraded",
				"account_id", account.ID,
				"cache_key", cacheKey,
				"error", lockErr,
			)
		case !acquired:
			// The lock is held by another worker.
			return &OAuthRefreshResult{LockHeld: true}, nil
		default:
			defer func() {
				// Release on a context detached from the caller's cancellation:
				// if ctx was cancelled or timed out while refreshing, releasing
				// with it would likely fail and leave the lock held until its
				// TTL expires.
				releaseCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), refreshLockReleaseTimeout)
				defer cancel()
				if releaseErr := api.tokenCache.ReleaseRefreshLock(releaseCtx, cacheKey); releaseErr != nil {
					// Best effort: the TTL will eventually expire the lock.
					slog.Warn("oauth_refresh_lock_release_failed",
						"account_id", account.ID,
						"cache_key", cacheKey,
						"error", releaseErr,
					)
				}
			}()
		}
	}

	// 2. Re-read the latest account from the DB (under the lock, so the most
	// recent refresh_token is used).
	freshAccount, err := api.accountRepo.GetByID(ctx, account.ID)
	if err != nil {
		slog.Warn("oauth_refresh_db_reread_failed",
			"account_id", account.ID,
			"error", err,
		)
		// Degrade to the account passed in by the caller.
		freshAccount = account
	} else if freshAccount == nil {
		freshAccount = account
	}

	// 3. Double-check whether a refresh is still needed (another path may
	// already have refreshed the token).
	if !executor.NeedsRefresh(freshAccount, refreshWindow) {
		return &OAuthRefreshResult{Account: freshAccount}, nil
	}

	// 4. Run the platform-specific refresh logic.
	newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
	if refreshErr != nil {
		return nil, refreshErr
	}

	// 5. Stamp the version and persist. When the DB re-read fell back to the
	// caller's account, this mutates the caller's *Account.
	if newCredentials != nil {
		newCredentials["_token_version"] = time.Now().UnixMilli()
		freshAccount.Credentials = newCredentials
		if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
			slog.Error("oauth_refresh_update_failed",
				"account_id", freshAccount.ID,
				"error", updateErr,
			)
			return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr)
		}
	}

	return &OAuthRefreshResult{
		Refreshed:      true,
		NewCredentials: newCredentials,
		Account:        freshAccount,
	}, nil
}

// MergeCredentials copies every field of oldCreds that is absent from
// newCreds into newCreds and returns newCreds. Keys already present in
// newCreds are never overwritten. newCreds is modified in place; when it is
// nil a new map is allocated and returned instead.
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
	if newCreds == nil {
		newCreds = make(map[string]any)
	}
	for k, v := range oldCreds {
		if _, exists := newCreds[k]; !exists {
			newCreds[k] = v
		}
	}
	return newCreds
}

// BuildClaudeAccountCredentials builds the OAuth credentials map for the
// Claude platform, which has no BuildAccountCredentials method of its own.
//
// expires_in and expires_at are stored as base-10 strings. refresh_token and
// scope are included only when non-empty. tokenInfo must be non-nil.
func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any {
	creds := map[string]any{
		"access_token": tokenInfo.AccessToken,
		"token_type":   tokenInfo.TokenType,
		"expires_in":   strconv.FormatInt(tokenInfo.ExpiresIn, 10),
		"expires_at":   strconv.FormatInt(tokenInfo.ExpiresAt, 10),
	}
	if tokenInfo.RefreshToken != "" {
		creds["refresh_token"] = tokenInfo.RefreshToken
	}
	if tokenInfo.Scope != "" {
		creds["scope"] = tokenInfo.Scope
	}
	return creds
}