Files
sub2api/backend/internal/service/token_refresh_service.go
erio 1fc9dd7b68 feat: unified OAuth token refresh API with distributed locking
Introduce OAuthRefreshAPI as the single entry point for all OAuth token
refresh operations, eliminating the race condition where background
refresh and inline refresh could simultaneously use the same
refresh_token (fixes #1035).

Key changes:
- Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
2026-03-16 01:31:54 +08:00

464 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
type TokenRefreshService struct {
accountRepo AccountRepository
refreshers []TokenRefresher
executors []OAuthRefreshExecutor // 与 refreshers 一一对应的 executor带 CacheKey
refreshPolicy BackgroundRefreshPolicy
cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
refreshAPI *OAuthRefreshAPI // 统一刷新 API
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
privacyClientFactory PrivacyClientFactory
proxyRepo ProxyRepository
stopCh chan struct{}
wg sync.WaitGroup
}
// NewTokenRefreshService 创建token刷新服务
func NewTokenRefreshService(
accountRepo AccountRepository,
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config,
tempUnschedCache TempUnschedCache,
) *TokenRefreshService {
s := &TokenRefreshService{
accountRepo: accountRepo,
refreshPolicy: DefaultBackgroundRefreshPolicy(),
cfg: &cfg.TokenRefresh,
cacheInvalidator: cacheInvalidator,
schedulerCache: schedulerCache,
tempUnschedCache: tempUnschedCache,
stopCh: make(chan struct{}),
}
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
claudeRefresher := NewClaudeTokenRefresher(oauthService)
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
// 注册平台特定的刷新器TokenRefresher 接口)
s.refreshers = []TokenRefresher{
claudeRefresher,
openAIRefresher,
geminiRefresher,
agRefresher,
}
// 注册对应的 OAuthRefreshExecutor带 CacheKey 方法)
s.executors = []OAuthRefreshExecutor{
claudeRefresher,
openAIRefresher,
geminiRefresher,
agRefresher,
}
return s
}
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
// 需要在 Start() 之前调用
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
for _, refresher := range s.refreshers {
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
openaiRefresher.SetSoraAccountRepo(repo)
}
}
}
// SetPrivacyDeps 注入 OpenAI privacy opt-out 所需依赖
func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxyRepo ProxyRepository) {
s.privacyClientFactory = factory
s.proxyRepo = proxyRepo
}
// SetRefreshAPI 注入统一的 OAuth 刷新 API
func (s *TokenRefreshService) SetRefreshAPI(api *OAuthRefreshAPI) {
s.refreshAPI = api
}
// SetRefreshPolicy 注入后台刷新调用侧策略(用于显式化平台/场景差异行为)。
func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
s.refreshPolicy = policy
}
// Start 启动后台刷新服务
func (s *TokenRefreshService) Start() {
if !s.cfg.Enabled {
slog.Info("token_refresh.service_disabled")
return
}
s.wg.Add(1)
go s.refreshLoop()
slog.Info("token_refresh.service_started",
"check_interval_minutes", s.cfg.CheckIntervalMinutes,
"refresh_before_expiry_hours", s.cfg.RefreshBeforeExpiryHours,
)
}
// Stop 停止刷新服务
func (s *TokenRefreshService) Stop() {
close(s.stopCh)
s.wg.Wait()
slog.Info("token_refresh.service_stopped")
}
// refreshLoop 刷新循环
func (s *TokenRefreshService) refreshLoop() {
defer s.wg.Done()
// 计算检查间隔
checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
if checkInterval < time.Minute {
checkInterval = 5 * time.Minute
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
// 启动时立即执行一次检查
s.processRefresh()
for {
select {
case <-ticker.C:
s.processRefresh()
case <-s.stopCh:
return
}
}
}
// processRefresh 执行一次刷新检查
func (s *TokenRefreshService) processRefresh() {
ctx := context.Background()
// 计算刷新窗口
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
// 获取所有active状态的账号
accounts, err := s.listActiveAccounts(ctx)
if err != nil {
slog.Error("token_refresh.list_accounts_failed", "error", err)
return
}
totalAccounts := len(accounts)
oauthAccounts := 0 // 可刷新的OAuth账号数
needsRefresh := 0 // 需要刷新的账号数
refreshed, failed, skipped := 0, 0, 0
for i := range accounts {
account := &accounts[i]
// 遍历所有刷新器,找到能处理此账号的
for idx, refresher := range s.refreshers {
if !refresher.CanRefresh(account) {
continue
}
oauthAccounts++
// 检查是否需要刷新
if !refresher.NeedsRefresh(account, refreshWindow) {
break // 不需要刷新,跳过
}
needsRefresh++
// 获取对应的 executor
var executor OAuthRefreshExecutor
if idx < len(s.executors) {
executor = s.executors[idx]
}
// 执行刷新
if err := s.refreshWithRetry(ctx, account, refresher, executor, refreshWindow); err != nil {
if errors.Is(err, errRefreshSkipped) {
skipped++
} else {
slog.Warn("token_refresh.account_refresh_failed",
"account_id", account.ID,
"account_name", account.Name,
"error", err,
)
failed++
}
} else {
slog.Info("token_refresh.account_refreshed",
"account_id", account.ID,
"account_name", account.Name,
)
refreshed++
}
// 每个账号只由一个refresher处理
break
}
}
// 无刷新活动时降级为 Debug有实际刷新活动时保持 Info
if needsRefresh == 0 && failed == 0 {
slog.Debug("token_refresh.cycle_completed",
"total", totalAccounts, "oauth", oauthAccounts,
"needs_refresh", needsRefresh, "refreshed", refreshed, "skipped", skipped, "failed", failed)
} else {
slog.Info("token_refresh.cycle_completed",
"total", totalAccounts,
"oauth", oauthAccounts,
"needs_refresh", needsRefresh,
"refreshed", refreshed,
"skipped", skipped,
"failed", failed,
)
}
}
// listActiveAccounts 获取所有active状态的账号
// 使用ListActive确保刷新所有活跃账号的token包括临时禁用的
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account, error) {
return s.accountRepo.ListActive(ctx)
}
// refreshWithRetry 带重试的刷新
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher, executor OAuthRefreshExecutor, refreshWindow time.Duration) error {
var lastErr error
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
var newCredentials map[string]any
var err error
// 优先使用统一 API带分布式锁 + DB 重读保护)
if s.refreshAPI != nil && executor != nil {
result, refreshErr := s.refreshAPI.RefreshIfNeeded(ctx, account, executor, refreshWindow)
if refreshErr != nil {
err = refreshErr
} else if result.LockHeld {
// 锁被其他 worker 持有,由调用侧策略决定如何计数
return s.refreshPolicy.handleLockHeld()
} else if !result.Refreshed {
// 已被其他路径刷新,由调用侧策略决定如何计数
return s.refreshPolicy.handleAlreadyRefreshed()
} else {
account = result.Account
newCredentials = result.NewCredentials
// 统一 API 已设置 _token_version 并更新 DB无需重复操作
}
} else {
// 降级:直接调用 refresher兼容旧路径
newCredentials, err = refresher.Refresh(ctx, account)
if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli()
account.Credentials = newCredentials
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", saveErr)
}
}
}
if err == nil {
s.postRefreshActions(ctx, account)
return nil
}
// 不可重试错误invalid_grant/invalid_client 等)直接标记 error 状态并返回
if isNonRetryableRefreshError(err) {
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
slog.Error("token_refresh.set_error_status_failed",
"account_id", account.ID,
"error", setErr,
)
}
return err
}
lastErr = err
slog.Warn("token_refresh.retry_attempt_failed",
"account_id", account.ID,
"attempt", attempt,
"max_retries", s.cfg.MaxRetries,
"error", err,
)
// 如果还有重试机会,等待后重试
if attempt < s.cfg.MaxRetries {
// 指数退避2^(attempt-1) * baseSeconds
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
time.Sleep(backoff)
}
}
// 可重试错误耗尽:仅记录日志,不标记 error可能是临时网络问题下个周期继续重试
slog.Warn("token_refresh.retry_exhausted",
"account_id", account.ID,
"platform", account.Platform,
"max_retries", s.cfg.MaxRetries,
"error", lastErr,
)
return lastErr
}
// postRefreshActions 刷新成功后的后续动作(清除错误状态、缓存失效、调度器同步等)
func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *Account) {
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error现在成功获取到了清除错误状态
if account.Platform == PlatformAntigravity &&
account.Status == StatusError &&
strings.Contains(account.ErrorMessage, "missing_project_id:") {
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_account_error_failed",
"account_id", account.ID,
"error", clearErr,
)
} else {
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
}
}
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
"account_id", account.ID,
"error", clearErr,
)
} else {
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
}
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
if s.tempUnschedCache != nil {
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
"account_id", account.ID,
"error", clearErr,
)
}
}
}
// 对所有 OAuth 账号调用缓存失效InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("token_refresh.invalidate_token_cache_failed",
"account_id", account.ID,
"error", err,
)
} else {
slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID)
}
}
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
if s.schedulerCache != nil {
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
slog.Warn("token_refresh.sync_scheduler_cache_failed",
"account_id", account.ID,
"error", err,
)
} else {
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
}
}
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode未设置则尝试关闭训练数据共享
s.ensureOpenAIPrivacy(ctx, account)
}
// errRefreshSkipped 表示刷新被跳过(锁竞争或已被其他路径刷新),不计入 failed 或 refreshed
var errRefreshSkipped = fmt.Errorf("refresh skipped")
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
func isNonRetryableRefreshError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
nonRetryable := []string{
"invalid_grant", // refresh_token 已失效
"invalid_client", // 客户端配置错误
"unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {
return true
}
}
return false
}
// ensureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode
// 未设置则调用 disableOpenAITraining 并持久化结果到 Extra。
func (s *TokenRefreshService) ensureOpenAIPrivacy(ctx context.Context, account *Account) {
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return
}
if s.privacyClientFactory == nil {
return
}
// 已设置过则跳过
if account.Extra != nil {
if _, ok := account.Extra["privacy_mode"]; ok {
return
}
}
token, _ := account.Credentials["access_token"].(string)
if token == "" {
return
}
var proxyURL string
if account.ProxyID != nil && s.proxyRepo != nil {
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
proxyURL = p.URL()
}
}
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
if mode == "" {
return
}
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}); err != nil {
slog.Warn("token_refresh.update_privacy_mode_failed",
"account_id", account.ID,
"error", err,
)
} else {
slog.Info("token_refresh.privacy_mode_set",
"account_id", account.ID,
"privacy_mode", mode,
)
}
}