mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-02 22:42:14 +08:00
Introduce OAuthRefreshAPI as the single entry point for all OAuth token refresh operations, eliminating the race condition where background refresh and inline refresh could simultaneously use the same refresh_token (fixes #1035).

Key changes:
- Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
160 lines
4.7 KiB
Go
160 lines
4.7 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"log/slog"
|
||
"strconv"
|
||
"time"
|
||
)
|
||
|
||
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
|
||
// TokenRefresher 接口的超集:增加了 CacheKey 方法用于分布式锁
|
||
type OAuthRefreshExecutor interface {
|
||
TokenRefresher
|
||
|
||
// CacheKey 返回用于分布式锁的缓存键(与 TokenProvider 使用的一致)
|
||
CacheKey(account *Account) string
|
||
}
|
||
|
||
// refreshLockTTL is the TTL passed to AcquireRefreshLock when RefreshIfNeeded
// takes the refresh lock.
const refreshLockTTL = 30 * time.Second
|
||
|
||
// OAuthRefreshResult 统一刷新结果
|
||
type OAuthRefreshResult struct {
|
||
Refreshed bool // 实际执行了刷新
|
||
NewCredentials map[string]any // 刷新后的 credentials(nil 表示未刷新)
|
||
Account *Account // 从 DB 重新读取的最新 account
|
||
LockHeld bool // 锁被其他 worker 持有(未执行刷新)
|
||
}
|
||
|
||
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
||
// 封装分布式锁、DB 重读、已刷新检查等通用逻辑
|
||
type OAuthRefreshAPI struct {
|
||
accountRepo AccountRepository
|
||
tokenCache GeminiTokenCache // 可选,nil = 无锁
|
||
}
|
||
|
||
// NewOAuthRefreshAPI 创建统一刷新 API
|
||
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
|
||
return &OAuthRefreshAPI{
|
||
accountRepo: accountRepo,
|
||
tokenCache: tokenCache,
|
||
}
|
||
}
|
||
|
||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
||
//
|
||
// 流程:
|
||
// 1. 获取分布式锁
|
||
// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||
// 3. 二次检查是否仍需刷新
|
||
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||
// 5. 设置 _token_version + 更新 DB
|
||
// 6. 释放锁
|
||
func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||
ctx context.Context,
|
||
account *Account,
|
||
executor OAuthRefreshExecutor,
|
||
refreshWindow time.Duration,
|
||
) (*OAuthRefreshResult, error) {
|
||
cacheKey := executor.CacheKey(account)
|
||
|
||
// 1. 获取分布式锁
|
||
lockAcquired := false
|
||
if api.tokenCache != nil {
|
||
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
|
||
if lockErr != nil {
|
||
// Redis 错误,降级为无锁刷新
|
||
slog.Warn("oauth_refresh_lock_failed_degraded",
|
||
"account_id", account.ID,
|
||
"cache_key", cacheKey,
|
||
"error", lockErr,
|
||
)
|
||
} else if !acquired {
|
||
// 锁被其他 worker 持有
|
||
return &OAuthRefreshResult{LockHeld: true}, nil
|
||
} else {
|
||
lockAcquired = true
|
||
defer func() { _ = api.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||
}
|
||
}
|
||
|
||
// 2. 从 DB 重读最新 account(锁保护下,确保使用最新的 refresh_token)
|
||
freshAccount, err := api.accountRepo.GetByID(ctx, account.ID)
|
||
if err != nil {
|
||
slog.Warn("oauth_refresh_db_reread_failed",
|
||
"account_id", account.ID,
|
||
"error", err,
|
||
)
|
||
// 降级使用传入的 account
|
||
freshAccount = account
|
||
} else if freshAccount == nil {
|
||
freshAccount = account
|
||
}
|
||
|
||
// 3. 二次检查是否仍需刷新(另一条路径可能已刷新)
|
||
if !executor.NeedsRefresh(freshAccount, refreshWindow) {
|
||
return &OAuthRefreshResult{
|
||
Account: freshAccount,
|
||
}, nil
|
||
}
|
||
|
||
// 4. 执行平台特定刷新逻辑
|
||
newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
|
||
if refreshErr != nil {
|
||
return nil, refreshErr
|
||
}
|
||
|
||
// 5. 设置版本号 + 更新 DB
|
||
if newCredentials != nil {
|
||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||
freshAccount.Credentials = newCredentials
|
||
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
|
||
slog.Error("oauth_refresh_update_failed",
|
||
"account_id", freshAccount.ID,
|
||
"error", updateErr,
|
||
)
|
||
return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr)
|
||
}
|
||
}
|
||
|
||
_ = lockAcquired // suppress unused warning when tokenCache is nil
|
||
|
||
return &OAuthRefreshResult{
|
||
Refreshed: true,
|
||
NewCredentials: newCredentials,
|
||
Account: freshAccount,
|
||
}, nil
|
||
}
|
||
|
||
// MergeCredentials carries over into newCreds every entry of oldCreds whose key
// is not already present in newCreds, and returns newCreds. Entries already in
// newCreds always win over the old value.
//
// newCreds is modified in place. When it is nil, a new map is allocated (sized
// for oldCreds) and returned instead. oldCreds is only read and may be nil.
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
	if newCreds == nil {
		// At most len(oldCreds) entries will be inserted, so size the map once.
		newCreds = make(map[string]any, len(oldCreds))
	}
	for key, value := range oldCreds {
		if _, exists := newCreds[key]; !exists {
			newCreds[key] = value
		}
	}
	return newCreds
}
|
||
|
||
// BuildClaudeAccountCredentials 为 Claude 平台构建 OAuth credentials map
|
||
// 消除 Claude 平台没有 BuildAccountCredentials 方法的问题
|
||
func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any {
|
||
creds := map[string]any{
|
||
"access_token": tokenInfo.AccessToken,
|
||
"token_type": tokenInfo.TokenType,
|
||
"expires_in": strconv.FormatInt(tokenInfo.ExpiresIn, 10),
|
||
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
|
||
}
|
||
if tokenInfo.RefreshToken != "" {
|
||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||
}
|
||
if tokenInfo.Scope != "" {
|
||
creds["scope"] = tokenInfo.Scope
|
||
}
|
||
return creds
|
||
}
|