From 1fc9dd7b68a61f49bf6dac9e4073ee6ad1016c2c Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 16 Mar 2026 01:31:54 +0800 Subject: [PATCH 1/2] feat: unified OAuth token refresh API with distributed locking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce OAuthRefreshAPI as the single entry point for all OAuth token refresh operations, eliminating the race condition where background refresh and inline refresh could simultaneously use the same refresh_token (fixes #1035). Key changes: - Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey - Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow - Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types - Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI - Rewrite TokenRefreshService.refreshWithRetry to use unified API path - Add MergeCredentials and BuildClaudeAccountCredentials helpers - Add 40 unit tests covering all new and modified code paths --- backend/cmd/server/wire_gen.go | 11 +- .../service/antigravity_token_provider.go | 103 +++-- .../service/antigravity_token_refresher.go | 11 +- .../internal/service/claude_token_provider.go | 172 +++----- .../internal/service/gemini_token_provider.go | 77 ++-- .../service/gemini_token_refresher.go | 11 +- backend/internal/service/oauth_refresh_api.go | 159 +++++++ .../service/oauth_refresh_api_test.go | 395 ++++++++++++++++++ .../internal/service/openai_token_provider.go | 176 ++++---- backend/internal/service/refresh_policy.go | 99 +++++ .../internal/service/token_refresh_service.go | 228 ++++++---- .../service/token_refresh_service_test.go | 237 ++++++++++- backend/internal/service/token_refresher.go | 39 +- backend/internal/service/wire.go | 70 +++- 14 files changed, 1336 insertions(+), 452 deletions(-) create mode 100644 backend/internal/service/oauth_refresh_api.go create mode 100644 backend/internal/service/oauth_refresh_api_test.go create mode 100644 backend/internal/service/refresh_policy.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 44f0af08..f632bff3 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -124,6 +124,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) + oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache) compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) @@ -132,11 +133,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { usageCache := service.NewUsageCache() identityCache := repository.NewIdentityCache(redisClient) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache) - geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) + geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) - antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) + antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) @@ -166,10 +167,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingService := service.NewBillingService(configConfig, pricingService) identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) + claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService) - openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) + openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) @@ -232,7 +233,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 068d6a08..9cdc49aa 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -3,7 +3,6 @@ package service import ( "context" "errors" - "log" "log/slog" "strconv" "strings" @@ -17,15 +16,18 @@ const ( antigravityBackfillCooldown = 5 * time.Minute ) -// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +// AntigravityTokenCache token cache interface. type AntigravityTokenCache = GeminiTokenCache -// AntigravityTokenProvider 管理 Antigravity 账户的 access_token +// AntigravityTokenProvider manages access_token for antigravity accounts. type AntigravityTokenProvider struct { accountRepo AccountRepository tokenCache AntigravityTokenCache antigravityOAuthService *AntigravityOAuthService - backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time + backfillCooldown sync.Map // key: accountID -> last attempt time + refreshAPI *OAuthRefreshAPI + executor OAuthRefreshExecutor + refreshPolicy ProviderRefreshPolicy } func NewAntigravityTokenProvider( @@ -37,10 +39,22 @@ func NewAntigravityTokenProvider( accountRepo: accountRepo, tokenCache: tokenCache, antigravityOAuthService: antigravityOAuthService, + refreshPolicy: AntigravityProviderRefreshPolicy(), } } -// GetAccessToken 获取有效的 access_token +// SetRefreshAPI injects unified OAuth refresh API and executor. +func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) { + p.refreshAPI = api + p.executor = executor +} + +// SetRefreshPolicy injects caller-side refresh policy. +func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { + p.refreshPolicy = policy +} + +// GetAccessToken returns a valid access_token. func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") @@ -48,7 +62,8 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * if account.Platform != PlatformAntigravity { return "", errors.New("not an antigravity account") } - // upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程 + + // upstream accounts use static api_key and never refresh oauth token. if account.Type == AccountTypeUpstream { apiKey := account.GetCredential("api_key") if apiKey == "" { @@ -62,46 +77,38 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * cacheKey := AntigravityTokenCacheKey(account) - // 1. 先尝试缓存 + // 1) Try cache first. if p.tokenCache != nil { if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { return token, nil } } - // 2. 如果即将过期则刷新 + // 2) Refresh if needed (pre-expiry skew). expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew - if needsRefresh && p.tokenCache != nil { + if needsRefresh && p.refreshAPI != nil && p.executor != nil { + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err + } + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil { + if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + // default policy: continue with existing token. + } else { + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } + } else if needsRefresh && p.tokenCache != nil { + // Backward-compatible test path when refreshAPI is not injected. locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) if err == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - - // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil - } - - // 从数据库获取最新账户信息 - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - expiresAt = account.GetCredentialAsTime("expires_at") - if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew { - if p.antigravityOAuthService == nil { - return "", errors.New("antigravity oauth service not configured") - } - tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - return "", err - } - p.mergeCredentials(account, tokenInfo) - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } } } @@ -110,32 +117,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return "", errors.New("access_token not found in credentials") } - // 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现 - // "Invalid project resource name projects/"。 - // 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。 + // Backfill project_id online when missing, with cooldown to avoid hammering. if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil { if p.shouldAttemptBackfill(account.ID) { p.markBackfillAttempted(account.ID) if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { account.Credentials["project_id"] = projectID if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr) + slog.Warn("antigravity_project_id_backfill_persist_failed", + "account_id", account.ID, + "error", updateErr, + ) } } } } - // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 3) Populate cache with TTL. if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { - // 版本过时,使用 DB 中的最新 token slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID) accessToken = latestAccount.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found after version check") } - // 不写入缓存,让下次请求重新处理 } else { ttl := 30 * time.Minute if expiresAt != nil { @@ -156,18 +162,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return accessToken, nil } -// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段 -func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) { - newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials -} - -// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试) +// shouldAttemptBackfill checks backfill cooldown. func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool { if v, ok := p.backfillCooldown.Load(accountID); ok { if lastAttempt, ok := v.(time.Time); ok { diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index e33f88d0..7ce0ccf0 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -25,6 +25,11 @@ func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthServi } } +// CacheKey 返回用于分布式锁的缓存键 +func (r *AntigravityTokenRefresher) CacheKey(account *Account) string { + return AntigravityTokenCacheKey(account) +} + // CanRefresh 检查是否可以刷新此账户 func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool { return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth @@ -58,11 +63,7 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo) // 合并旧的 credentials,保留新 credentials 中不存在的字段 - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } + newCredentials = MergeCredentials(account.Credentials, newCredentials) // 特殊处理 project_id:如果新值为空但旧值非空,保留旧值 // 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失 diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index f6cab204..82fa31c4 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -4,7 +4,6 @@ import ( "context" "errors" "log/slog" - "strconv" "strings" "time" ) @@ -15,14 +14,17 @@ const ( claudeLockWaitTime = 200 * time.Millisecond ) -// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +// ClaudeTokenCache token cache interface. type ClaudeTokenCache = GeminiTokenCache -// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token +// ClaudeTokenProvider manages access_token for Claude OAuth accounts. type ClaudeTokenProvider struct { - accountRepo AccountRepository - tokenCache ClaudeTokenCache - oauthService *OAuthService + accountRepo AccountRepository + tokenCache ClaudeTokenCache + oauthService *OAuthService + refreshAPI *OAuthRefreshAPI + executor OAuthRefreshExecutor + refreshPolicy ProviderRefreshPolicy } func NewClaudeTokenProvider( @@ -31,13 +33,25 @@ func NewClaudeTokenProvider( oauthService *OAuthService, ) *ClaudeTokenProvider { return &ClaudeTokenProvider{ - accountRepo: accountRepo, - tokenCache: tokenCache, - oauthService: oauthService, + accountRepo: accountRepo, + tokenCache: tokenCache, + oauthService: oauthService, + refreshPolicy: ClaudeProviderRefreshPolicy(), } } -// GetAccessToken 获取有效的 access_token +// SetRefreshAPI injects unified OAuth refresh API and executor. +func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) { + p.refreshAPI = api + p.executor = executor +} + +// SetRefreshPolicy injects caller-side refresh policy. +func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { + p.refreshPolicy = policy +} + +// GetAccessToken returns a valid access_token. func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") @@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou cacheKey := ClaudeTokenCacheKey(account) - // 1. 先尝试缓存 + // 1) Try cache first. if p.tokenCache != nil { if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("claude_token_cache_hit", "account_id", account.ID) @@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou slog.Debug("claude_token_cache_miss", "account_id", account.ID) - // 2. 如果即将过期则刷新 + // 2) Refresh if needed (pre-expiry skew). expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew refreshFailed := false - if needsRefresh && p.tokenCache != nil { - locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) - if lockErr == nil && locked { - defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil + if needsRefresh && p.refreshAPI != nil && p.executor != nil { + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err } - - // 从数据库获取最新账户信息 - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - expiresAt = account.GetCredentialAsTime("expires_at") - if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew { - if p.oauthService == nil { - slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID) - refreshFailed = true // 无法刷新,标记失败 - } else { - tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account) - if err != nil { - // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token - slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err) - refreshFailed = true // 刷新失败,标记以使用短 TTL - } else { - // 构建新 credentials,保留原有字段 - newCredentials := make(map[string]any) - for k, v := range account.Credentials { - newCredentials[k] = v - } - newCredentials["access_token"] = tokenInfo.AccessToken - newCredentials["token_type"] = tokenInfo.TokenType - newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) - newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) - if tokenInfo.RefreshToken != "" { - newCredentials["refresh_token"] = tokenInfo.RefreshToken - } - if tokenInfo.Scope != "" { - newCredentials["scope"] = tokenInfo.Scope - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } - } - } - } else if lockErr != nil { - // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) - slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) - - // 检查 ctx 是否已取消 - if ctx.Err() != nil { - return "", ctx.Err() - } - - // 从数据库获取最新账户信息 - if p.accountRepo != nil { - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - } - expiresAt = account.GetCredentialAsTime("expires_at") - - // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 - if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew { - if p.oauthService == nil { - slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID) - refreshFailed = true - } else { - tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account) - if err != nil { - slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err) - refreshFailed = true - } else { - // 构建新 credentials,保留原有字段 - newCredentials := make(map[string]any) - for k, v := range account.Credentials { - newCredentials[k] = v - } - newCredentials["access_token"] = tokenInfo.AccessToken - newCredentials["token_type"] = tokenInfo.TokenType - newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) - newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) - if tokenInfo.RefreshToken != "" { - newCredentials["refresh_token"] = tokenInfo.RefreshToken - } - if tokenInfo.Scope != "" { - newCredentials["scope"] = tokenInfo.Scope - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } + slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err) + refreshFailed = true + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil { + time.Sleep(claudeLockWaitTime) + if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" { + slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil } } } else { - // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } + } else if needsRefresh && p.tokenCache != nil { + // Backward-compatible test path when refreshAPI is not injected. + locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + } else if lockErr != nil { + slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr) + } else { time.Sleep(claudeLockWaitTime) if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID) @@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou return "", errors.New("access_token not found in credentials") } - // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 3) Populate cache with TTL. if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { - // 版本过时,使用 DB 中的最新 token slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID) accessToken = latestAccount.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found after version check") } - // 不写入缓存,让下次请求重新处理 } else { ttl := 30 * time.Minute if refreshFailed { - // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 - ttl = time.Minute + if p.refreshPolicy.FailureTTL > 0 { + ttl = p.refreshPolicy.FailureTTL + } else { + ttl = time.Minute + } slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") } else if expiresAt != nil { until := time.Until(*expiresAt) diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 313b048f..1dab67c4 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -15,10 +15,14 @@ const ( geminiTokenCacheSkew = 5 * time.Minute ) +// GeminiTokenProvider manages access_token for Gemini OAuth accounts. type GeminiTokenProvider struct { accountRepo AccountRepository tokenCache GeminiTokenCache geminiOAuthService *GeminiOAuthService + refreshAPI *OAuthRefreshAPI + executor OAuthRefreshExecutor + refreshPolicy ProviderRefreshPolicy } func NewGeminiTokenProvider( @@ -30,9 +34,21 @@ func NewGeminiTokenProvider( accountRepo: accountRepo, tokenCache: tokenCache, geminiOAuthService: geminiOAuthService, + refreshPolicy: GeminiProviderRefreshPolicy(), } } +// SetRefreshAPI injects unified OAuth refresh API and executor. +func (p *GeminiTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) { + p.refreshAPI = api + p.executor = executor +} + +// SetRefreshPolicy injects caller-side refresh policy. +func (p *GeminiTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { + p.refreshPolicy = policy +} + func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") @@ -53,39 +69,31 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou // 2) Refresh if needed (pre-expiry skew). expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew - if needsRefresh && p.tokenCache != nil { - locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) - if err == nil && locked { - defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - // Re-check after lock (another worker may have refreshed). - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil + if needsRefresh && p.refreshAPI != nil && p.executor != nil { + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, geminiTokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err } - - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil { + if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" { + return token, nil + } } + slog.Debug("gemini_token_lock_held_use_old", "account_id", account.ID) + } else { + account = result.Account expiresAt = account.GetCredentialAsTime("expires_at") - if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew { - if p.geminiOAuthService == nil { - return "", errors.New("gemini oauth service not configured") - } - tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - return "", err - } - newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials - _ = p.accountRepo.Update(ctx, account) - expiresAt = account.GetCredentialAsTime("expires_at") - } + } + } else if needsRefresh && p.tokenCache != nil { + // Backward-compatible test path when refreshAPI is not injected. + locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + } else if lockErr != nil { + slog.Warn("gemini_token_lock_failed", "account_id", account.ID, "error", lockErr) } } @@ -95,15 +103,14 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } // project_id is optional now: - // - If present: will use Code Assist API (requires project_id) - // - If absent: will use AI Studio API with OAuth token (like regular API key mode) - // Auto-detect project_id only if explicitly enabled via a credential flag + // - If present: use Code Assist API (requires project_id) + // - If absent: use AI Studio API with OAuth token. projectID := strings.TrimSpace(account.GetCredential("project_id")) autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true" if projectID == "" && autoDetectProjectID { if p.geminiOAuthService == nil { - return accessToken, nil // Fallback to AI Studio API mode + return accessToken, nil } var proxyURL string @@ -132,17 +139,15 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - // 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 3) Populate cache with TTL. if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { - // 版本过时,使用 DB 中的最新 token slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID) accessToken = latestAccount.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found after version check") } - // 不写入缓存,让下次请求重新处理 } else { ttl := 30 * time.Minute if expiresAt != nil { diff --git a/backend/internal/service/gemini_token_refresher.go b/backend/internal/service/gemini_token_refresher.go index 7dfc5521..d5e502da 100644 --- a/backend/internal/service/gemini_token_refresher.go +++ b/backend/internal/service/gemini_token_refresher.go @@ -13,6 +13,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService} } +// CacheKey 返回用于分布式锁的缓存键 +func (r *GeminiTokenRefresher) CacheKey(account *Account) string { + return GeminiTokenCacheKey(account) +} + func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool { return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth } @@ -35,11 +40,7 @@ func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (m } newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } + newCredentials = MergeCredentials(account.Credentials, newCredentials) return newCredentials, nil } diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go new file mode 100644 index 00000000..17b9128c --- /dev/null +++ b/backend/internal/service/oauth_refresh_api.go @@ -0,0 +1,159 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strconv" + "time" +) + +// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器 +// TokenRefresher 接口的超集:增加了 CacheKey 方法用于分布式锁 +type OAuthRefreshExecutor interface { + TokenRefresher + + // CacheKey 返回用于分布式锁的缓存键(与 TokenProvider 使用的一致) + CacheKey(account *Account) string +} + +const refreshLockTTL = 30 * time.Second + +// OAuthRefreshResult 统一刷新结果 +type OAuthRefreshResult struct { + Refreshed bool // 实际执行了刷新 + NewCredentials map[string]any // 刷新后的 credentials(nil 表示未刷新) + Account *Account // 从 DB 重新读取的最新 account + LockHeld bool // 锁被其他 worker 持有(未执行刷新) +} + +// OAuthRefreshAPI 统一的 OAuth Token 刷新入口 +// 封装分布式锁、DB 重读、已刷新检查等通用逻辑 +type OAuthRefreshAPI struct { + accountRepo AccountRepository + tokenCache GeminiTokenCache // 可选,nil = 无锁 +} + +// NewOAuthRefreshAPI 创建统一刷新 API +func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { + return &OAuthRefreshAPI{ + accountRepo: accountRepo, + tokenCache: tokenCache, + } +} + +// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token +// +// 流程: +// 1. 获取分布式锁 +// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token) +// 3. 二次检查是否仍需刷新 +// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑 +// 5. 设置 _token_version + 更新 DB +// 6. 释放锁 +func (api *OAuthRefreshAPI) RefreshIfNeeded( + ctx context.Context, + account *Account, + executor OAuthRefreshExecutor, + refreshWindow time.Duration, +) (*OAuthRefreshResult, error) { + cacheKey := executor.CacheKey(account) + + // 1. 获取分布式锁 + lockAcquired := false + if api.tokenCache != nil { + acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL) + if lockErr != nil { + // Redis 错误,降级为无锁刷新 + slog.Warn("oauth_refresh_lock_failed_degraded", + "account_id", account.ID, + "cache_key", cacheKey, + "error", lockErr, + ) + } else if !acquired { + // 锁被其他 worker 持有 + return &OAuthRefreshResult{LockHeld: true}, nil + } else { + lockAcquired = true + defer func() { _ = api.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + } + } + + // 2. 从 DB 重读最新 account(锁保护下,确保使用最新的 refresh_token) + freshAccount, err := api.accountRepo.GetByID(ctx, account.ID) + if err != nil { + slog.Warn("oauth_refresh_db_reread_failed", + "account_id", account.ID, + "error", err, + ) + // 降级使用传入的 account + freshAccount = account + } else if freshAccount == nil { + freshAccount = account + } + + // 3. 二次检查是否仍需刷新(另一条路径可能已刷新) + if !executor.NeedsRefresh(freshAccount, refreshWindow) { + return &OAuthRefreshResult{ + Account: freshAccount, + }, nil + } + + // 4. 执行平台特定刷新逻辑 + newCredentials, refreshErr := executor.Refresh(ctx, freshAccount) + if refreshErr != nil { + return nil, refreshErr + } + + // 5. 设置版本号 + 更新 DB + if newCredentials != nil { + newCredentials["_token_version"] = time.Now().UnixMilli() + freshAccount.Credentials = newCredentials + if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil { + slog.Error("oauth_refresh_update_failed", + "account_id", freshAccount.ID, + "error", updateErr, + ) + return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr) + } + } + + _ = lockAcquired // suppress unused warning when tokenCache is nil + + return &OAuthRefreshResult{ + Refreshed: true, + NewCredentials: newCredentials, + Account: freshAccount, + }, nil +} + +// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中 +func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any { + if newCreds == nil { + newCreds = make(map[string]any) + } + for k, v := range oldCreds { + if _, exists := newCreds[k]; !exists { + newCreds[k] = v + } + } + return newCreds +} + +// BuildClaudeAccountCredentials 为 Claude 平台构建 OAuth credentials map +// 消除 Claude 平台没有 BuildAccountCredentials 方法的问题 +func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any { + creds := map[string]any{ + "access_token": tokenInfo.AccessToken, + "token_type": tokenInfo.TokenType, + "expires_in": strconv.FormatInt(tokenInfo.ExpiresIn, 10), + "expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10), + } + if tokenInfo.RefreshToken != "" { + creds["refresh_token"] = tokenInfo.RefreshToken + } + if tokenInfo.Scope != "" { + creds["scope"] = tokenInfo.Scope + } + return creds +} diff --git a/backend/internal/service/oauth_refresh_api_test.go b/backend/internal/service/oauth_refresh_api_test.go new file mode 100644 index 00000000..6cf9371f --- /dev/null +++ b/backend/internal/service/oauth_refresh_api_test.go @@ -0,0 +1,395 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// ---------- mock helpers ---------- + +// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests. +type refreshAPIAccountRepo struct { + mockAccountRepoForGemini + account *Account // returned by GetByID + getByIDErr error + updateErr error + updateCalls int +} + +func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) { + if r.getByIDErr != nil { + return nil, r.getByIDErr + } + return r.account, nil +} + +func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error { + r.updateCalls++ + return r.updateErr +} + +// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests. +type refreshAPIExecutorStub struct { + needsRefresh bool + credentials map[string]any + err error + refreshCalls int +} + +func (e *refreshAPIExecutorStub) CanRefresh(_ *Account) bool { return true } + +func (e *refreshAPIExecutorStub) NeedsRefresh(_ *Account, _ time.Duration) bool { + return e.needsRefresh +} + +func (e *refreshAPIExecutorStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) { + e.refreshCalls++ + if e.err != nil { + return nil, e.err + } + return e.credentials, nil +} + +func (e *refreshAPIExecutorStub) CacheKey(account *Account) string { + return "test:api:" + account.Platform +} + +// refreshAPICacheStub implements GeminiTokenCache for OAuthRefreshAPI tests. +type refreshAPICacheStub struct { + lockResult bool + lockErr error + releaseCalls int +} + +func (c *refreshAPICacheStub) GetAccessToken(context.Context, string) (string, error) { + return "", nil +} + +func (c *refreshAPICacheStub) SetAccessToken(context.Context, string, string, time.Duration) error { + return nil +} + +func (c *refreshAPICacheStub) DeleteAccessToken(context.Context, string) error { return nil } + +func (c *refreshAPICacheStub) AcquireRefreshLock(context.Context, string, time.Duration) (bool, error) { + return c.lockResult, c.lockErr +} + +func (c *refreshAPICacheStub) ReleaseRefreshLock(context.Context, string) error { + c.releaseCalls++ + return nil +} + +// ========== RefreshIfNeeded tests ========== + +func TestRefreshIfNeeded_Success(t *testing.T) { + account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "new-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.NotNil(t, result.NewCredentials) + require.Equal(t, "new-token", result.NewCredentials["access_token"]) + require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set + require.Equal(t, 1, repo.updateCalls) // DB updated + require.Equal(t, 1, cache.releaseCalls) // lock released + require.Equal(t, 1, executor.refreshCalls) +} + +func TestRefreshIfNeeded_LockHeld(t *testing.T) { + account := &Account{ID: 2, Platform: PlatformAnthropic} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: false} // lock not acquired + executor := &refreshAPIExecutorStub{needsRefresh: true} + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.LockHeld) + require.False(t, result.Refreshed) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, executor.refreshCalls) +} + +func TestRefreshIfNeeded_LockErrorDegrades(t *testing.T) { + account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockErr: errors.New("redis down")} // lock error + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "degraded-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) // still refreshed (degraded mode) + require.Equal(t, 1, repo.updateCalls) // DB updated + require.Equal(t, 0, cache.releaseCalls) // no lock to release + require.Equal(t, 1, executor.refreshCalls) +} + +func TestRefreshIfNeeded_NoCacheNoLock(t *testing.T) { + account := &Account{ID: 4, Platform: PlatformGemini, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{account: account} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "no-cache-token"}, + } + + api := NewOAuthRefreshAPI(repo, nil) // no cache = no lock + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Equal(t, 1, repo.updateCalls) +} + +func TestRefreshIfNeeded_AlreadyRefreshed(t *testing.T) { + account := &Account{ID: 5, Platform: PlatformAnthropic} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{needsRefresh: false} // already refreshed + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.False(t, result.Refreshed) + require.False(t, result.LockHeld) + require.NotNil(t, result.Account) // returns fresh account + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, executor.refreshCalls) +} + +func TestRefreshIfNeeded_RefreshError(t *testing.T) { + account := &Account{ID: 6, Platform: PlatformAnthropic} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant: token revoked"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "invalid_grant") + require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error + require.Equal(t, 1, cache.releaseCalls) // lock still released via defer +} + +func TestRefreshIfNeeded_DBUpdateError(t *testing.T) { + account := &Account{ID: 7, Platform: PlatformGemini, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{ + account: account, + updateErr: errors.New("db connection lost"), + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "DB update failed") + require.Equal(t, 1, repo.updateCalls) // attempted +} + +func TestRefreshIfNeeded_DBRereadFails(t *testing.T) { + account := &Account{ID: 8, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{ + account: nil, // GetByID returns nil + getByIDErr: errors.New("db timeout"), + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "fallback-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Equal(t, 1, executor.refreshCalls) // still refreshes using passed-in account +} + +func TestRefreshIfNeeded_NilCredentials(t *testing.T) { + account := &Account{ID: 9, Platform: PlatformGemini, Type: AccountTypeOAuth} + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: nil, // Refresh returns nil credentials + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Nil(t, result.NewCredentials) + require.Equal(t, 0, repo.updateCalls) // no DB update when credentials are nil +} + +// ========== MergeCredentials tests ========== + +func TestMergeCredentials_Basic(t *testing.T) { + old := map[string]any{"a": "1", "b": "2", "c": "3"} + new := map[string]any{"a": "new", "d": "4"} + + result := MergeCredentials(old, new) + + require.Equal(t, "new", result["a"]) // new value preserved + require.Equal(t, "2", result["b"]) // old value kept + require.Equal(t, "3", result["c"]) // old value kept + require.Equal(t, "4", result["d"]) // new value preserved +} + +func TestMergeCredentials_NilNew(t *testing.T) { + old := map[string]any{"a": "1"} + + result := MergeCredentials(old, nil) + + require.NotNil(t, result) + require.Equal(t, "1", result["a"]) +} + +func TestMergeCredentials_NilOld(t *testing.T) { + new := map[string]any{"a": "1"} + + result := MergeCredentials(nil, new) + + require.Equal(t, "1", result["a"]) +} + +func TestMergeCredentials_BothNil(t *testing.T) { + result := MergeCredentials(nil, nil) + require.NotNil(t, result) + require.Empty(t, result) +} + +func TestMergeCredentials_NewOverridesOld(t *testing.T) { + old := map[string]any{"access_token": "old-token", "refresh_token": "old-refresh"} + new := map[string]any{"access_token": "new-token"} + + result := MergeCredentials(old, new) + + require.Equal(t, "new-token", result["access_token"]) // overridden + require.Equal(t, "old-refresh", result["refresh_token"]) // preserved +} + +// ========== BuildClaudeAccountCredentials tests ========== + +func TestBuildClaudeAccountCredentials_Full(t *testing.T) { + tokenInfo := &TokenInfo{ + AccessToken: "at-123", + TokenType: "Bearer", + ExpiresIn: 3600, + ExpiresAt: 1700000000, + RefreshToken: "rt-456", + Scope: "openid", + } + + creds := BuildClaudeAccountCredentials(tokenInfo) + + require.Equal(t, "at-123", creds["access_token"]) + require.Equal(t, "Bearer", creds["token_type"]) + require.Equal(t, "3600", creds["expires_in"]) + require.Equal(t, "1700000000", creds["expires_at"]) + require.Equal(t, "rt-456", creds["refresh_token"]) + require.Equal(t, "openid", creds["scope"]) +} + +func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) { + tokenInfo := &TokenInfo{ + AccessToken: "at-789", + TokenType: "Bearer", + ExpiresIn: 7200, + ExpiresAt: 1700003600, + } + + creds := BuildClaudeAccountCredentials(tokenInfo) + + require.Equal(t, "at-789", creds["access_token"]) + require.Equal(t, "Bearer", creds["token_type"]) + require.Equal(t, "7200", creds["expires_in"]) + require.Equal(t, "1700003600", creds["expires_at"]) + _, hasRefresh := creds["refresh_token"] + _, hasScope := creds["scope"] + require.False(t, hasRefresh, "refresh_token should not be set when empty") + require.False(t, hasScope, "scope should not be set when empty") +} + +// ========== BackgroundRefreshPolicy tests ========== + +func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) { + p := DefaultBackgroundRefreshPolicy() + + require.ErrorIs(t, p.handleLockHeld(), errRefreshSkipped) + require.ErrorIs(t, p.handleAlreadyRefreshed(), errRefreshSkipped) +} + +func TestBackgroundRefreshPolicy_SuccessOverride(t *testing.T) { + p := BackgroundRefreshPolicy{ + OnLockHeld: BackgroundSkipAsSuccess, + OnAlreadyRefresh: BackgroundSkipAsSuccess, + } + + require.NoError(t, p.handleLockHeld()) + require.NoError(t, p.handleAlreadyRefreshed()) +} + +// ========== ProviderRefreshPolicy tests ========== + +func TestClaudeProviderRefreshPolicy(t *testing.T) { + p := ClaudeProviderRefreshPolicy() + require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError) + require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld) + require.Equal(t, time.Minute, p.FailureTTL) +} + +func TestOpenAIProviderRefreshPolicy(t *testing.T) { + p := OpenAIProviderRefreshPolicy() + require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError) + require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld) + require.Equal(t, time.Minute, p.FailureTTL) +} + +func TestGeminiProviderRefreshPolicy(t *testing.T) { + p := GeminiProviderRefreshPolicy() + require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError) + require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld) + require.Equal(t, time.Duration(0), p.FailureTTL) +} + +func TestAntigravityProviderRefreshPolicy(t *testing.T) { + p := AntigravityProviderRefreshPolicy() + require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError) + require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld) + require.Equal(t, time.Duration(0), p.FailureTTL) +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index a8a6b96c..69477ce7 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -20,7 +20,7 @@ const ( openAILockWarnThresholdMs = 250 ) -// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。 +// OpenAITokenRuntimeMetrics is a snapshot of refresh and lock contention metrics. type OpenAITokenRuntimeMetrics struct { RefreshRequests int64 RefreshSuccess int64 @@ -72,15 +72,18 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() { m.lastObservedUnixMs.Store(time.Now().UnixMilli()) } -// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) +// OpenAITokenCache token cache interface. type OpenAITokenCache = GeminiTokenCache -// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token +// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts. type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache openAIOAuthService *OpenAIOAuthService metrics *openAITokenRuntimeMetricsStore + refreshAPI *OAuthRefreshAPI + executor OAuthRefreshExecutor + refreshPolicy ProviderRefreshPolicy } func NewOpenAITokenProvider( @@ -93,9 +96,21 @@ func NewOpenAITokenProvider( tokenCache: tokenCache, openAIOAuthService: openAIOAuthService, metrics: &openAITokenRuntimeMetricsStore{}, + refreshPolicy: OpenAIProviderRefreshPolicy(), } } +// SetRefreshAPI injects unified OAuth refresh API and executor. +func (p *OpenAITokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) { + p.refreshAPI = api + p.executor = executor +} + +// SetRefreshPolicy injects caller-side refresh policy. +func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { + p.refreshPolicy = policy +} + func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics { if p == nil { return OpenAITokenRuntimeMetrics{} @@ -110,7 +125,7 @@ func (p *OpenAITokenProvider) ensureMetrics() { } } -// GetAccessToken 获取有效的 access_token +// GetAccessToken returns a valid access_token. func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { p.ensureMetrics() if account == nil { @@ -122,7 +137,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou cacheKey := OpenAITokenCacheKey(account) - // 1. 先尝试缓存 + // 1) Try cache first. if p.tokenCache != nil { if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("openai_token_cache_hit", "account_id", account.ID) @@ -134,114 +149,62 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou slog.Debug("openai_token_cache_miss", "account_id", account.ID) - // 2. 如果即将过期则刷新 + // 2) Refresh if needed (pre-expiry skew). expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew refreshFailed := false - if needsRefresh && p.tokenCache != nil { + + if needsRefresh && p.refreshAPI != nil && p.executor != nil { + p.metrics.refreshRequests.Add(1) + p.metrics.touchNow() + + // Sora accounts skip OpenAI OAuth refresh and keep existing token path. + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + refreshFailed = true + } else { + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err + } + slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) + refreshFailed = true + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache { + p.metrics.lockContention.Add(1) + p.metrics.touchNow() + token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) + if waitErr != nil { + return "", waitErr + } + if strings.TrimSpace(token) != "" { + slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil + } + } + } else if result.Refreshed { + p.metrics.refreshSuccess.Add(1) + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } else { + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } + } + } else if needsRefresh && p.tokenCache != nil { + // Backward-compatible test path when refreshAPI is not injected. p.metrics.refreshRequests.Add(1) p.metrics.touchNow() locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) if lockErr == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - - // 拿到锁后再次检查缓存(另一个 worker 可能已刷新) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil - } - - // 从数据库获取最新账户信息 - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - expiresAt = account.GetCredentialAsTime("expires_at") - if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if account.Platform == PlatformSora { - slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) - // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 - refreshFailed = true - } else if p.openAIOAuthService == nil { - slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) - p.metrics.refreshFailure.Add(1) - refreshFailed = true // 无法刷新,标记失败 - } else { - tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token - slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) - p.metrics.refreshFailure.Add(1) - refreshFailed = true // 刷新失败,标记以使用短 TTL - } else { - p.metrics.refreshSuccess.Add(1) - newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } - } - } } else if lockErr != nil { - // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) p.metrics.lockAcquireFailure.Add(1) p.metrics.touchNow() - slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) - - // 检查 ctx 是否已取消 - if ctx.Err() != nil { - return "", ctx.Err() - } - - // 从数据库获取最新账户信息 - if p.accountRepo != nil { - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh - } - } - expiresAt = account.GetCredentialAsTime("expires_at") - - // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 - if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if account.Platform == PlatformSora { - slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID) - // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 - refreshFailed = true - } else if p.openAIOAuthService == nil { - slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) - p.metrics.refreshFailure.Add(1) - refreshFailed = true - } else { - tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err) - p.metrics.refreshFailure.Add(1) - refreshFailed = true - } else { - p.metrics.refreshSuccess.Add(1) - newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) - } - expiresAt = account.GetCredentialAsTime("expires_at") - } - } - } + slog.Warn("openai_token_lock_failed", "account_id", account.ID, "error", lockErr) } else { - // 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。 p.metrics.lockContention.Add(1) p.metrics.touchNow() token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) @@ -260,22 +223,23 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou return "", errors.New("access_token not found in credentials") } - // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 3) Populate cache with TTL. if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { - // 版本过时,使用 DB 中的最新 token slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID) accessToken = latestAccount.GetOpenAIAccessToken() if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found after version check") } - // 不写入缓存,让下次请求重新处理 } else { ttl := 30 * time.Minute if refreshFailed { - // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动 - ttl = time.Minute + if p.refreshPolicy.FailureTTL > 0 { + ttl = p.refreshPolicy.FailureTTL + } else { + ttl = time.Minute + } slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed") } else if expiresAt != nil { until := time.Until(*expiresAt) diff --git a/backend/internal/service/refresh_policy.go b/backend/internal/service/refresh_policy.go new file mode 100644 index 00000000..7f299be0 --- /dev/null +++ b/backend/internal/service/refresh_policy.go @@ -0,0 +1,99 @@ +package service + +import "time" + +// ProviderRefreshErrorAction 定义 provider 在刷新失败时的处理动作。 +type ProviderRefreshErrorAction int + +const ( + // ProviderRefreshErrorReturn 失败即返回错误(不降级旧 token)。 + ProviderRefreshErrorReturn ProviderRefreshErrorAction = iota + // ProviderRefreshErrorUseExistingToken 失败后继续使用现有 token。 + ProviderRefreshErrorUseExistingToken +) + +// ProviderLockHeldAction 定义 provider 在刷新锁被占用时的处理动作。 +type ProviderLockHeldAction int + +const ( + // ProviderLockHeldUseExistingToken 直接使用现有 token。 + ProviderLockHeldUseExistingToken ProviderLockHeldAction = iota + // ProviderLockHeldWaitForCache 等待后重试缓存读取。 + ProviderLockHeldWaitForCache +) + +// ProviderRefreshPolicy 描述 provider 的平台差异策略。 +type ProviderRefreshPolicy struct { + OnRefreshError ProviderRefreshErrorAction + OnLockHeld ProviderLockHeldAction + FailureTTL time.Duration +} + +func ClaudeProviderRefreshPolicy() ProviderRefreshPolicy { + return ProviderRefreshPolicy{ + OnRefreshError: ProviderRefreshErrorUseExistingToken, + OnLockHeld: ProviderLockHeldWaitForCache, + FailureTTL: time.Minute, + } +} + +func OpenAIProviderRefreshPolicy() ProviderRefreshPolicy { + return ProviderRefreshPolicy{ + OnRefreshError: ProviderRefreshErrorUseExistingToken, + OnLockHeld: ProviderLockHeldWaitForCache, + FailureTTL: time.Minute, + } +} + +func GeminiProviderRefreshPolicy() ProviderRefreshPolicy { + return ProviderRefreshPolicy{ + OnRefreshError: ProviderRefreshErrorReturn, + OnLockHeld: ProviderLockHeldUseExistingToken, + FailureTTL: 0, + } +} + +func AntigravityProviderRefreshPolicy() ProviderRefreshPolicy { + return ProviderRefreshPolicy{ + OnRefreshError: ProviderRefreshErrorReturn, + OnLockHeld: ProviderLockHeldUseExistingToken, + FailureTTL: 0, + } +} + +// BackgroundSkipAction 定义后台刷新服务在“未实际刷新”场景的计数方式。 +type BackgroundSkipAction int + +const ( + // BackgroundSkipAsSkipped 计入 skipped(保持当前默认行为)。 + BackgroundSkipAsSkipped BackgroundSkipAction = iota + // BackgroundSkipAsSuccess 计入 success(仅用于兼容旧统计口径时可选)。 + BackgroundSkipAsSuccess +) + +// BackgroundRefreshPolicy 描述后台刷新服务的调用侧策略。 +type BackgroundRefreshPolicy struct { + OnLockHeld BackgroundSkipAction + OnAlreadyRefresh BackgroundSkipAction +} + +func DefaultBackgroundRefreshPolicy() BackgroundRefreshPolicy { + return BackgroundRefreshPolicy{ + OnLockHeld: BackgroundSkipAsSkipped, + OnAlreadyRefresh: BackgroundSkipAsSkipped, + } +} + +func (p BackgroundRefreshPolicy) handleLockHeld() error { + if p.OnLockHeld == BackgroundSkipAsSuccess { + return nil + } + return errRefreshSkipped +} + +func (p BackgroundRefreshPolicy) handleAlreadyRefreshed() error { + if p.OnAlreadyRefresh == BackgroundSkipAsSuccess { + return nil + } + return errRefreshSkipped +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 1825257c..cb00d5e5 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "log/slog" "strings" @@ -16,10 +17,13 @@ import ( type TokenRefreshService struct { accountRepo AccountRepository refreshers []TokenRefresher + executors []OAuthRefreshExecutor // 与 refreshers 一一对应的 executor(带 CacheKey) + refreshPolicy BackgroundRefreshPolicy cfg *config.TokenRefreshConfig cacheInvalidator TokenCacheInvalidator schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存 + refreshAPI *OAuthRefreshAPI // 统一刷新 API // OpenAI privacy: 刷新成功后检查并设置 training opt-out privacyClientFactory PrivacyClientFactory @@ -43,6 +47,7 @@ func NewTokenRefreshService( ) *TokenRefreshService { s := &TokenRefreshService{ accountRepo: accountRepo, + refreshPolicy: DefaultBackgroundRefreshPolicy(), cfg: &cfg.TokenRefresh, cacheInvalidator: cacheInvalidator, schedulerCache: schedulerCache, @@ -53,12 +58,24 @@ func NewTokenRefreshService( openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo) openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts) - // 注册平台特定的刷新器 + claudeRefresher := NewClaudeTokenRefresher(oauthService) + geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService) + agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService) + + // 注册平台特定的刷新器(TokenRefresher 接口) s.refreshers = []TokenRefresher{ - NewClaudeTokenRefresher(oauthService), + claudeRefresher, openAIRefresher, - NewGeminiTokenRefresher(geminiOAuthService), - NewAntigravityTokenRefresher(antigravityOAuthService), + geminiRefresher, + agRefresher, + } + + // 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法) + s.executors = []OAuthRefreshExecutor{ + claudeRefresher, + openAIRefresher, + geminiRefresher, + agRefresher, } return s @@ -82,6 +99,16 @@ func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxy s.proxyRepo = proxyRepo } +// SetRefreshAPI 注入统一的 OAuth 刷新 API +func (s *TokenRefreshService) SetRefreshAPI(api *OAuthRefreshAPI) { + s.refreshAPI = api +} + +// SetRefreshPolicy 注入后台刷新调用侧策略(用于显式化平台/场景差异行为)。 +func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) { + s.refreshPolicy = policy +} + // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { @@ -148,13 +175,13 @@ func (s *TokenRefreshService) processRefresh() { totalAccounts := len(accounts) oauthAccounts := 0 // 可刷新的OAuth账号数 needsRefresh := 0 // 需要刷新的账号数 - refreshed, failed := 0, 0 + refreshed, failed, skipped := 0, 0, 0 for i := range accounts { account := &accounts[i] // 遍历所有刷新器,找到能处理此账号的 - for _, refresher := range s.refreshers { + for idx, refresher := range s.refreshers { if !refresher.CanRefresh(account) { continue } @@ -168,14 +195,24 @@ func (s *TokenRefreshService) processRefresh() { needsRefresh++ + // 获取对应的 executor + var executor OAuthRefreshExecutor + if idx < len(s.executors) { + executor = s.executors[idx] + } + // 执行刷新 - if err := s.refreshWithRetry(ctx, account, refresher); err != nil { - slog.Warn("token_refresh.account_refresh_failed", - "account_id", account.ID, - "account_name", account.Name, - "error", err, - ) - failed++ + if err := s.refreshWithRetry(ctx, account, refresher, executor, refreshWindow); err != nil { + if errors.Is(err, errRefreshSkipped) { + skipped++ + } else { + slog.Warn("token_refresh.account_refresh_failed", + "account_id", account.ID, + "account_name", account.Name, + "error", err, + ) + failed++ + } } else { slog.Info("token_refresh.account_refreshed", "account_id", account.ID, @@ -193,13 +230,14 @@ func (s *TokenRefreshService) processRefresh() { if needsRefresh == 0 && failed == 0 { slog.Debug("token_refresh.cycle_completed", "total", totalAccounts, "oauth", oauthAccounts, - "needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed) + "needs_refresh", needsRefresh, "refreshed", refreshed, "skipped", skipped, "failed", failed) } else { slog.Info("token_refresh.cycle_completed", "total", totalAccounts, "oauth", oauthAccounts, "needs_refresh", needsRefresh, "refreshed", refreshed, + "skipped", skipped, "failed", failed, ) } @@ -212,83 +250,43 @@ func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account } // refreshWithRetry 带重试的刷新 -func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error { +func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher, executor OAuthRefreshExecutor, refreshWindow time.Duration) error { var lastErr error for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ { - newCredentials, err := refresher.Refresh(ctx, account) + var newCredentials map[string]any + var err error - // 如果有新凭证,先更新(即使有错误也要保存 token) - if newCredentials != nil { - // 记录刷新版本时间戳,用于解决缓存一致性问题 - // TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入 - newCredentials["_token_version"] = time.Now().UnixMilli() - - account.Credentials = newCredentials - if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { - return fmt.Errorf("failed to save credentials: %w", saveErr) + // 优先使用统一 API(带分布式锁 + DB 重读保护) + if s.refreshAPI != nil && executor != nil { + result, refreshErr := s.refreshAPI.RefreshIfNeeded(ctx, account, executor, refreshWindow) + if refreshErr != nil { + err = refreshErr + } else if result.LockHeld { + // 锁被其他 worker 持有,由调用侧策略决定如何计数 + return s.refreshPolicy.handleLockHeld() + } else if !result.Refreshed { + // 已被其他路径刷新,由调用侧策略决定如何计数 + return s.refreshPolicy.handleAlreadyRefreshed() + } else { + account = result.Account + newCredentials = result.NewCredentials + // 统一 API 已设置 _token_version 并更新 DB,无需重复操作 + } + } else { + // 降级:直接调用 refresher(兼容旧路径) + newCredentials, err = refresher.Refresh(ctx, account) + if newCredentials != nil { + newCredentials["_token_version"] = time.Now().UnixMilli() + account.Credentials = newCredentials + if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { + return fmt.Errorf("failed to save credentials: %w", saveErr) + } } } if err == nil { - // Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态 - if account.Platform == PlatformAntigravity && - account.Status == StatusError && - strings.Contains(account.ErrorMessage, "missing_project_id:") { - if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil { - slog.Warn("token_refresh.clear_account_error_failed", - "account_id", account.ID, - "error", clearErr, - ) - } else { - slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID) - } - } - // 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景) - if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { - if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil { - slog.Warn("token_refresh.clear_temp_unschedulable_failed", - "account_id", account.ID, - "error", clearErr, - ) - } else { - slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID) - } - // 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态 - if s.tempUnschedCache != nil { - if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil { - slog.Warn("token_refresh.clear_temp_unsched_cache_failed", - "account_id", account.ID, - "error", clearErr, - ) - } - } - } - // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) - if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { - if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { - slog.Warn("token_refresh.invalidate_token_cache_failed", - "account_id", account.ID, - "error", err, - ) - } else { - slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID) - } - } - // 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials - // 这解决了 token 刷新后调度器缓存数据不一致的问题(#445) - if s.schedulerCache != nil { - if err := s.schedulerCache.SetAccount(ctx, account); err != nil { - slog.Warn("token_refresh.sync_scheduler_cache_failed", - "account_id", account.ID, - "error", err, - ) - } else { - slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID) - } - } - // OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享 - s.ensureOpenAIPrivacy(ctx, account) + s.postRefreshActions(ctx, account) return nil } @@ -331,6 +329,70 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc return lastErr } +// postRefreshActions 刷新成功后的后续动作(清除错误状态、缓存失效、调度器同步等) +func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *Account) { + // Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态 + if account.Platform == PlatformAntigravity && + account.Status == StatusError && + strings.Contains(account.ErrorMessage, "missing_project_id:") { + if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil { + slog.Warn("token_refresh.clear_account_error_failed", + "account_id", account.ID, + "error", clearErr, + ) + } else { + slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID) + } + } + // 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景) + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil { + slog.Warn("token_refresh.clear_temp_unschedulable_failed", + "account_id", account.ID, + "error", clearErr, + ) + } else { + slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID) + } + // 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态 + if s.tempUnschedCache != nil { + if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil { + slog.Warn("token_refresh.clear_temp_unsched_cache_failed", + "account_id", account.ID, + "error", clearErr, + ) + } + } + } + // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) + if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { + if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil { + slog.Warn("token_refresh.invalidate_token_cache_failed", + "account_id", account.ID, + "error", err, + ) + } else { + slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID) + } + } + // 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials + if s.schedulerCache != nil { + if err := s.schedulerCache.SetAccount(ctx, account); err != nil { + slog.Warn("token_refresh.sync_scheduler_cache_failed", + "account_id", account.ID, + "error", err, + ) + } else { + slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID) + } + } + // OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享 + s.ensureOpenAIPrivacy(ctx, account) +} + +// errRefreshSkipped 表示刷新被跳过(锁竞争或已被其他路径刷新),不计入 failed 或 refreshed +var errRefreshSkipped = fmt.Errorf("refresh skipped") + // isNonRetryableRefreshError 判断是否为不可重试的刷新错误 // 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权 // 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误 diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index bdef0ed7..f48de65e 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -84,6 +84,10 @@ func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map return r.credentials, nil } +func (r *tokenRefresherStub) CacheKey(account *Account) string { + return "test:stub:" + account.Platform +} + func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{} @@ -105,7 +109,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, invalidator.calls) @@ -133,7 +137,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, invalidator.calls) @@ -159,7 +163,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) } @@ -186,7 +190,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效 @@ -214,7 +218,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效 @@ -242,7 +246,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 @@ -270,7 +274,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Contains(t, err.Error(), "failed to save credentials") require.Equal(t, 1, repo.updateCalls) @@ -297,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) { err: errors.New("refresh failed"), } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新 require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效 @@ -324,7 +328,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin err: errors.New("network error"), // 可重试错误 } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Equal(t, 0, repo.updateCalls) require.Equal(t, 0, invalidator.calls) @@ -351,7 +355,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te err: errors.New("invalid_grant: token revoked"), // 不可重试错误 } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Equal(t, 0, repo.updateCalls) require.Equal(t, 0, invalidator.calls) @@ -383,7 +387,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing }, } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) require.Equal(t, 1, repo.clearTempCalls) // DB 清除 @@ -422,7 +426,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t err: errors.New("invalid_grant: token revoked"), } - err := service.refreshWithRetry(context.Background(), account, refresher) + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.Error(t, err) require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError }) @@ -453,3 +457,212 @@ func TestIsNonRetryableRefreshError(t *testing.T) { }) } } + +// ========== Path A (refreshAPI) 测试用例 ========== + +// mockTokenCacheForRefreshAPI 用于 Path A 测试的 GeminiTokenCache mock +type mockTokenCacheForRefreshAPI struct { + lockResult bool + lockErr error + releaseCalls int +} + +func (m *mockTokenCacheForRefreshAPI) GetAccessToken(_ context.Context, _ string) (string, error) { + return "", errors.New("not cached") +} + +func (m *mockTokenCacheForRefreshAPI) SetAccessToken(_ context.Context, _ string, _ string, _ time.Duration) error { + return nil +} + +func (m *mockTokenCacheForRefreshAPI) DeleteAccessToken(_ context.Context, _ string) error { + return nil +} + +func (m *mockTokenCacheForRefreshAPI) AcquireRefreshLock(_ context.Context, _ string, _ time.Duration) (bool, error) { + return m.lockResult, m.lockErr +} + +func (m *mockTokenCacheForRefreshAPI) ReleaseRefreshLock(_ context.Context, _ string) error { + m.releaseCalls++ + return nil +} + +// buildPathAService 构建注入了 refreshAPI 的 service(Path A 测试辅助) +func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, invalidator TokenCacheInvalidator) (*TokenRefreshService, *tokenRefresherStub) { + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + refreshAPI := NewOAuthRefreshAPI(repo, cache) + service.SetRefreshAPI(refreshAPI) + + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "refreshed-token", + }, + } + return service, refresher +} + +// TestPathA_Success 统一 API 路径正常成功:刷新 + DB 更新 + postRefreshActions +func TestPathA_Success(t *testing.T) { + account := &Account{ + ID: 100, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + service, refresher := buildPathAService(repo, cache, invalidator) + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCalls) // DB 更新被调用 + require.Equal(t, 1, invalidator.calls) // 缓存失效被调用 + require.Equal(t, 1, cache.releaseCalls) // 锁被释放 +} + +// TestPathA_LockHeld 锁被其他 worker 持有 → 返回 errRefreshSkipped +func TestPathA_LockHeld(t *testing.T) { + account := &Account{ + ID: 101, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: false} // 锁获取失败(被占) + + service, refresher := buildPathAService(repo, cache, invalidator) + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.ErrorIs(t, err, errRefreshSkipped) + require.Equal(t, 0, repo.updateCalls) // 不应更新 DB + require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效 +} + +// TestPathA_AlreadyRefreshed 二次检查发现已被其他路径刷新 → 返回 errRefreshSkipped +func TestPathA_AlreadyRefreshed(t *testing.T) { + // NeedsRefresh 返回 false → RefreshIfNeeded 返回 {Refreshed: false} + account := &Account{ + ID: 102, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + service, _ := buildPathAService(repo, cache, invalidator) + + // 使用一个 NeedsRefresh 返回 false 的 stub + noRefreshNeeded := &tokenRefresherStub{ + credentials: map[string]any{"access_token": "token"}, + } + // 覆盖 NeedsRefresh 行为 — 我们需要一个新的 stub 类型 + alwaysFreshStub := &alwaysFreshRefresherStub{} + + err := service.refreshWithRetry(context.Background(), account, noRefreshNeeded, alwaysFreshStub, time.Hour) + require.ErrorIs(t, err, errRefreshSkipped) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, invalidator.calls) +} + +// alwaysFreshRefresherStub 二次检查时认为不需要刷新(模拟已被其他路径刷新) +type alwaysFreshRefresherStub struct{} + +func (r *alwaysFreshRefresherStub) CanRefresh(_ *Account) bool { return true } +func (r *alwaysFreshRefresherStub) NeedsRefresh(_ *Account, _ time.Duration) bool { return false } +func (r *alwaysFreshRefresherStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) { + return nil, errors.New("should not be called") +} +func (r *alwaysFreshRefresherStub) CacheKey(account *Account) string { + return "test:fresh:" + account.Platform +} + +// TestPathA_NonRetryableError 统一 API 路径返回不可重试错误 → SetError +func TestPathA_NonRetryableError(t *testing.T) { + account := &Account{ + ID: 103, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + service, _ := buildPathAService(repo, cache, invalidator) + + refresher := &tokenRefresherStub{ + err: errors.New("invalid_grant: token revoked"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Equal(t, 1, repo.setErrorCalls) // 应标记 error 状态 + require.Equal(t, 0, repo.updateCalls) // 不应更新 credentials + require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效 +} + +// TestPathA_RetryableErrorExhausted 统一 API 路径可重试错误耗尽 → 不标记 error +func TestPathA_RetryableErrorExhausted(t *testing.T) { + account := &Account{ + ID: 104, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 2, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil) + refreshAPI := NewOAuthRefreshAPI(repo, cache) + service.SetRefreshAPI(refreshAPI) + + refresher := &tokenRefresherStub{ + err: errors.New("network timeout"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Equal(t, 0, repo.setErrorCalls) // 可重试错误不标记 error + require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新 + require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效 +} + +// TestPathA_DBUpdateFailed 统一 API 路径 DB 更新失败 → 返回 error,不执行 postRefreshActions +func TestPathA_DBUpdateFailed(t *testing.T) { + account := &Account{ + ID: 105, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + } + repo := &tokenRefreshAccountRepo{updateErr: errors.New("db connection lost")} + repo.accountsByID = map[int64]*Account{account.ID: account} + invalidator := &tokenCacheInvalidatorStub{} + cache := &mockTokenCacheForRefreshAPI{lockResult: true} + + service, refresher := buildPathAService(repo, cache, invalidator) + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Contains(t, err.Error(), "DB update failed") + require.Equal(t, 1, repo.updateCalls) // DB 更新被尝试 + require.Equal(t, 0, invalidator.calls) // DB 失败时不应触发缓存失效 +} diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 0dd3cf45..5a214161 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -3,7 +3,6 @@ package service import ( "context" "log" - "strconv" "time" ) @@ -33,6 +32,11 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher { } } +// CacheKey 返回用于分布式锁的缓存键 +func (r *ClaudeTokenRefresher) CacheKey(account *Account) string { + return ClaudeTokenCacheKey(account) +} + // CanRefresh 检查是否能处理此账号 // 只处理 anthropic 平台的 oauth 类型账号 // setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新 @@ -59,24 +63,8 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m return nil, err } - // 保留现有credentials中的所有字段 - newCredentials := make(map[string]any) - for k, v := range account.Credentials { - newCredentials[k] = v - } - - // 只更新token相关字段 - // 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型 - newCredentials["access_token"] = tokenInfo.AccessToken - newCredentials["token_type"] = tokenInfo.TokenType - newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) - newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) - if tokenInfo.RefreshToken != "" { - newCredentials["refresh_token"] = tokenInfo.RefreshToken - } - if tokenInfo.Scope != "" { - newCredentials["scope"] = tokenInfo.Scope - } + newCredentials := BuildClaudeAccountCredentials(tokenInfo) + newCredentials = MergeCredentials(account.Credentials, newCredentials) return newCredentials, nil } @@ -97,6 +85,11 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo } } +// CacheKey 返回用于分布式锁的缓存键 +func (r *OpenAITokenRefresher) CacheKey(account *Account) string { + return OpenAITokenCacheKey(account) +} + // SetSoraAccountRepo 设置 Sora 账号扩展表仓储 // 用于在 Token 刷新时同步更新 sora_accounts 表 // 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials @@ -137,13 +130,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m // 使用服务提供的方法构建新凭证,并保留原有字段 newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo) - - // 保留原有credentials中非token相关字段 - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } + newCredentials = MergeCredentials(account.Credentials, newCredentials) // 异步同步关联的 Sora 账号(不阻塞主流程) if r.accountRepo != nil && r.syncLinkedSora { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 3d2d5d68..7da72630 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -51,16 +51,77 @@ func ProvideTokenRefreshService( tempUnschedCache TempUnschedCache, privacyClientFactory PrivacyClientFactory, proxyRepo ProxyRepository, + refreshAPI *OAuthRefreshAPI, ) *TokenRefreshService { svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) // 注入 OpenAI privacy opt-out 依赖 svc.SetPrivacyDeps(privacyClientFactory, proxyRepo) + // 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件) + svc.SetRefreshAPI(refreshAPI) + // 调用侧显式注入后台刷新策略,避免策略漂移 + svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy()) svc.Start() return svc } +// ProvideClaudeTokenProvider creates ClaudeTokenProvider with OAuthRefreshAPI injection +func ProvideClaudeTokenProvider( + accountRepo AccountRepository, + tokenCache GeminiTokenCache, + oauthService *OAuthService, + refreshAPI *OAuthRefreshAPI, +) *ClaudeTokenProvider { + p := NewClaudeTokenProvider(accountRepo, tokenCache, oauthService) + executor := NewClaudeTokenRefresher(oauthService) + p.SetRefreshAPI(refreshAPI, executor) + p.SetRefreshPolicy(ClaudeProviderRefreshPolicy()) + return p +} + +// ProvideOpenAITokenProvider creates OpenAITokenProvider with OAuthRefreshAPI injection +func ProvideOpenAITokenProvider( + accountRepo AccountRepository, + tokenCache GeminiTokenCache, + openaiOAuthService *OpenAIOAuthService, + refreshAPI *OAuthRefreshAPI, +) *OpenAITokenProvider { + p := NewOpenAITokenProvider(accountRepo, tokenCache, openaiOAuthService) + executor := NewOpenAITokenRefresher(openaiOAuthService, accountRepo) + p.SetRefreshAPI(refreshAPI, executor) + p.SetRefreshPolicy(OpenAIProviderRefreshPolicy()) + return p +} + +// ProvideGeminiTokenProvider creates GeminiTokenProvider with OAuthRefreshAPI injection +func ProvideGeminiTokenProvider( + accountRepo AccountRepository, + tokenCache GeminiTokenCache, + geminiOAuthService *GeminiOAuthService, + refreshAPI *OAuthRefreshAPI, +) *GeminiTokenProvider { + p := NewGeminiTokenProvider(accountRepo, tokenCache, geminiOAuthService) + executor := NewGeminiTokenRefresher(geminiOAuthService) + p.SetRefreshAPI(refreshAPI, executor) + p.SetRefreshPolicy(GeminiProviderRefreshPolicy()) + return p +} + +// ProvideAntigravityTokenProvider creates AntigravityTokenProvider with OAuthRefreshAPI injection +func ProvideAntigravityTokenProvider( + accountRepo AccountRepository, + tokenCache GeminiTokenCache, + antigravityOAuthService *AntigravityOAuthService, + refreshAPI *OAuthRefreshAPI, +) *AntigravityTokenProvider { + p := NewAntigravityTokenProvider(accountRepo, tokenCache, antigravityOAuthService) + executor := NewAntigravityTokenRefresher(antigravityOAuthService) + p.SetRefreshAPI(refreshAPI, executor) + p.SetRefreshPolicy(AntigravityProviderRefreshPolicy()) + return p +} + // ProvideDashboardAggregationService 创建并启动仪表盘聚合服务 func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { svc := NewDashboardAggregationService(repo, timingWheel, cfg) @@ -375,11 +436,12 @@ var ProviderSet = wire.NewSet( NewCompositeTokenCacheInvalidator, wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)), NewAntigravityOAuthService, - NewGeminiTokenProvider, + NewOAuthRefreshAPI, + ProvideGeminiTokenProvider, NewGeminiMessagesCompatService, - NewAntigravityTokenProvider, - NewOpenAITokenProvider, - NewClaudeTokenProvider, + ProvideAntigravityTokenProvider, + ProvideOpenAITokenProvider, + ProvideClaudeTokenProvider, NewAntigravityGatewayService, ProvideRateLimitService, NewAccountUsageService, From 044d3a013d5ca8c9a7764a89b881ae08e97a322b Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 16 Mar 2026 01:38:06 +0800 Subject: [PATCH 2/2] fix: suppress SA4006 unused value warning in Path A branch --- backend/internal/service/token_refresh_service.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index cb00d5e5..cb8841b0 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -270,8 +270,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc return s.refreshPolicy.handleAlreadyRefreshed() } else { account = result.Account - newCredentials = result.NewCredentials - // 统一 API 已设置 _token_version 并更新 DB,无需重复操作 + _ = result.NewCredentials // 统一 API 已设置 _token_version 并更新 DB,无需重复操作 } } else { // 降级:直接调用 refresher(兼容旧路径)