mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-19 14:24:45 +08:00
feat: unified OAuth token refresh API with distributed locking
Introduce OAuthRefreshAPI as the single entry point for all OAuth token refresh operations, eliminating the race condition where background refresh and inline refresh could simultaneously use the same refresh_token (fixes #1035).

Key changes:
- Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -15,14 +14,17 @@ const (
|
||||
claudeLockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
// ClaudeTokenCache is the token cache interface used by ClaudeTokenProvider.
// It is a type alias of GeminiTokenCache (not a new type), so any
// GeminiTokenCache value can be used wherever a ClaudeTokenCache is expected.
|
||||
type ClaudeTokenCache = GeminiTokenCache
|
||||
|
||||
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
|
||||
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
// GetAccessToken consults tokenCache first and, when the token is missing or
// close to expiry, refreshes it through refreshAPI (if injected) before
// repopulating the cache.
|
||||
type ClaudeTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
oauthService *OAuthService
|
||||
// NOTE(review): the three fields below repeat the three above. This looks
// like the removed rows followed by the re-aligned added rows of the diff,
// not a real duplicate declaration — confirm against the actual file.
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
oauthService *OAuthService
|
||||
// refreshAPI and executor are injected together via SetRefreshAPI.
// GetAccessToken only takes the unified refresh path when both are non-nil.
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
// refreshPolicy controls how GetAccessToken reacts to refresh errors and to
// a refresh lock held elsewhere, and the cache TTL used after a failed
// refresh. The constructor sets it to ClaudeProviderRefreshPolicy().
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewClaudeTokenProvider(
|
||||
@@ -31,13 +33,25 @@ func NewClaudeTokenProvider(
|
||||
oauthService *OAuthService,
|
||||
) *ClaudeTokenProvider {
|
||||
return &ClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
oauthService: oauthService,
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
oauthService: oauthService,
|
||||
refreshPolicy: ClaudeProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
// SetRefreshAPI injects the unified OAuth refresh API and the executor that is
// passed to it. GetAccessToken calls refreshAPI.RefreshIfNeeded only when both
// values are non-nil; otherwise it falls back to the cache-lock path kept for
// backward compatibility. This is a plain field assignment with no
// synchronization.
|
||||
func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
// SetRefreshPolicy replaces the caller-side refresh policy, overriding the
// ClaudeProviderRefreshPolicy() default installed by the constructor.
// GetAccessToken reads OnRefreshError, OnLockHeld and FailureTTL from it.
// This is a plain field assignment with no synchronization.
|
||||
func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
// GetAccessToken returns a valid access_token.
|
||||
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
@@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
||||
@@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
// Backward-compatible test path when refreshAPI is not injected.
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
} else if lockErr != nil {
|
||||
slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
} else {
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
@@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
if p.refreshPolicy.FailureTTL > 0 {
|
||||
ttl = p.refreshPolicy.FailureTTL
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
|
||||
Reference in New Issue
Block a user