mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-18 05:44:46 +08:00
Merge pull request #1042 from touwaeriol/feat/unified-oauth-refresh-api
feat: unified OAuth token refresh API with distributed locking
This commit is contained in:
@@ -124,6 +124,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
|
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
@@ -132,11 +133,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
identityCache := repository.NewIdentityCache(redisClient)
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
@@ -166,10 +167,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
@@ -232,7 +233,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,15 +16,18 @@ const (
|
|||||||
antigravityBackfillCooldown = 5 * time.Minute
|
antigravityBackfillCooldown = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
// AntigravityTokenCache token cache interface.
|
||||||
type AntigravityTokenCache = GeminiTokenCache
|
type AntigravityTokenCache = GeminiTokenCache
|
||||||
|
|
||||||
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
|
// AntigravityTokenProvider manages access_token for antigravity accounts.
|
||||||
type AntigravityTokenProvider struct {
|
type AntigravityTokenProvider struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache AntigravityTokenCache
|
tokenCache AntigravityTokenCache
|
||||||
antigravityOAuthService *AntigravityOAuthService
|
antigravityOAuthService *AntigravityOAuthService
|
||||||
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
|
backfillCooldown sync.Map // key: accountID -> last attempt time
|
||||||
|
refreshAPI *OAuthRefreshAPI
|
||||||
|
executor OAuthRefreshExecutor
|
||||||
|
refreshPolicy ProviderRefreshPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAntigravityTokenProvider(
|
func NewAntigravityTokenProvider(
|
||||||
@@ -37,10 +39,22 @@ func NewAntigravityTokenProvider(
|
|||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
antigravityOAuthService: antigravityOAuthService,
|
antigravityOAuthService: antigravityOAuthService,
|
||||||
|
refreshPolicy: AntigravityProviderRefreshPolicy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取有效的 access_token
|
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||||
|
func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||||
|
p.refreshAPI = api
|
||||||
|
p.executor = executor
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRefreshPolicy injects caller-side refresh policy.
|
||||||
|
func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||||
|
p.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessToken returns a valid access_token.
|
||||||
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
@@ -48,7 +62,8 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
if account.Platform != PlatformAntigravity {
|
if account.Platform != PlatformAntigravity {
|
||||||
return "", errors.New("not an antigravity account")
|
return "", errors.New("not an antigravity account")
|
||||||
}
|
}
|
||||||
// upstream 类型:直接从 credentials 读取 api_key,不走 OAuth 刷新流程
|
|
||||||
|
// upstream accounts use static api_key and never refresh oauth token.
|
||||||
if account.Type == AccountTypeUpstream {
|
if account.Type == AccountTypeUpstream {
|
||||||
apiKey := account.GetCredential("api_key")
|
apiKey := account.GetCredential("api_key")
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
@@ -62,46 +77,38 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
|
|
||||||
cacheKey := AntigravityTokenCacheKey(account)
|
cacheKey := AntigravityTokenCacheKey(account)
|
||||||
|
|
||||||
// 1. 先尝试缓存
|
// 1) Try cache first.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 如果即将过期则刷新
|
// 2) Refresh if needed (pre-expiry skew).
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
||||||
if needsRefresh && p.tokenCache != nil {
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew)
|
||||||
|
if err != nil {
|
||||||
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else if result.LockHeld {
|
||||||
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||||
|
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// default policy: continue with existing token.
|
||||||
|
} else {
|
||||||
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
|
// Backward-compatible test path when refreshAPI is not injected.
|
||||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
if err == nil && locked {
|
if err == nil && locked {
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
|
||||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从数据库获取最新账户信息
|
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
|
||||||
if err == nil && fresh != nil {
|
|
||||||
account = fresh
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
|
|
||||||
if p.antigravityOAuthService == nil {
|
|
||||||
return "", errors.New("antigravity oauth service not configured")
|
|
||||||
}
|
|
||||||
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
p.mergeCredentials(account, tokenInfo)
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,32 +117,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
|
// Backfill project_id online when missing, with cooldown to avoid hammering.
|
||||||
// "Invalid project resource name projects/"。
|
|
||||||
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
|
|
||||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||||
if p.shouldAttemptBackfill(account.ID) {
|
if p.shouldAttemptBackfill(account.ID) {
|
||||||
p.markBackfillAttempted(account.ID)
|
p.markBackfillAttempted(account.ID)
|
||||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||||
account.Credentials["project_id"] = projectID
|
account.Credentials["project_id"] = projectID
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||||
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
|
slog.Warn("antigravity_project_id_backfill_persist_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", updateErr,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 3) Populate cache with TTL.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if isStale && latestAccount != nil {
|
if isStale && latestAccount != nil {
|
||||||
// 版本过时,使用 DB 中的最新 token
|
|
||||||
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
accessToken = latestAccount.GetCredential("access_token")
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
if strings.TrimSpace(accessToken) == "" {
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
return "", errors.New("access_token not found after version check")
|
return "", errors.New("access_token not found after version check")
|
||||||
}
|
}
|
||||||
// 不写入缓存,让下次请求重新处理
|
|
||||||
} else {
|
} else {
|
||||||
ttl := 30 * time.Minute
|
ttl := 30 * time.Minute
|
||||||
if expiresAt != nil {
|
if expiresAt != nil {
|
||||||
@@ -156,18 +162,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
|
// shouldAttemptBackfill checks backfill cooldown.
|
||||||
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
|
|
||||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
|
|
||||||
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
||||||
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
||||||
if lastAttempt, ok := v.(time.Time); ok {
|
if lastAttempt, ok := v.(time.Time); ok {
|
||||||
|
|||||||
@@ -25,6 +25,11 @@ func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthServi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheKey 返回用于分布式锁的缓存键
|
||||||
|
func (r *AntigravityTokenRefresher) CacheKey(account *Account) string {
|
||||||
|
return AntigravityTokenCacheKey(account)
|
||||||
|
}
|
||||||
|
|
||||||
// CanRefresh 检查是否可以刷新此账户
|
// CanRefresh 检查是否可以刷新此账户
|
||||||
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||||
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
||||||
@@ -58,11 +63,7 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
|||||||
|
|
||||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
||||||
for k, v := range account.Credentials {
|
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -15,14 +14,17 @@ const (
|
|||||||
claudeLockWaitTime = 200 * time.Millisecond
|
claudeLockWaitTime = 200 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
// ClaudeTokenCache token cache interface.
|
||||||
type ClaudeTokenCache = GeminiTokenCache
|
type ClaudeTokenCache = GeminiTokenCache
|
||||||
|
|
||||||
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
|
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
|
||||||
type ClaudeTokenProvider struct {
|
type ClaudeTokenProvider struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache ClaudeTokenCache
|
tokenCache ClaudeTokenCache
|
||||||
oauthService *OAuthService
|
oauthService *OAuthService
|
||||||
|
refreshAPI *OAuthRefreshAPI
|
||||||
|
executor OAuthRefreshExecutor
|
||||||
|
refreshPolicy ProviderRefreshPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClaudeTokenProvider(
|
func NewClaudeTokenProvider(
|
||||||
@@ -31,13 +33,25 @@ func NewClaudeTokenProvider(
|
|||||||
oauthService *OAuthService,
|
oauthService *OAuthService,
|
||||||
) *ClaudeTokenProvider {
|
) *ClaudeTokenProvider {
|
||||||
return &ClaudeTokenProvider{
|
return &ClaudeTokenProvider{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
|
refreshPolicy: ClaudeProviderRefreshPolicy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取有效的 access_token
|
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||||
|
func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||||
|
p.refreshAPI = api
|
||||||
|
p.executor = executor
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRefreshPolicy injects caller-side refresh policy.
|
||||||
|
func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||||
|
p.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessToken returns a valid access_token.
|
||||||
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
@@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
cacheKey := ClaudeTokenCacheKey(account)
|
cacheKey := ClaudeTokenCacheKey(account)
|
||||||
|
|
||||||
// 1. 先尝试缓存
|
// 1) Try cache first.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
||||||
@@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
||||||
|
|
||||||
// 2. 如果即将过期则刷新
|
// 2) Refresh if needed (pre-expiry skew).
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||||
refreshFailed := false
|
refreshFailed := false
|
||||||
if needsRefresh && p.tokenCache != nil {
|
|
||||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
|
||||||
if lockErr == nil && locked {
|
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
|
||||||
|
|
||||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew)
|
||||||
return token, nil
|
if err != nil {
|
||||||
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
|
return "", err
|
||||||
}
|
}
|
||||||
|
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||||
// 从数据库获取最新账户信息
|
refreshFailed = true
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
} else if result.LockHeld {
|
||||||
if err == nil && fresh != nil {
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||||
account = fresh
|
time.Sleep(claudeLockWaitTime)
|
||||||
}
|
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
return token, nil
|
||||||
if p.oauthService == nil {
|
|
||||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
|
||||||
refreshFailed = true // 无法刷新,标记失败
|
|
||||||
} else {
|
|
||||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
|
||||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
|
||||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
|
||||||
} else {
|
|
||||||
// 构建新 credentials,保留原有字段
|
|
||||||
newCredentials := make(map[string]any)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
|
||||||
newCredentials["token_type"] = tokenInfo.TokenType
|
|
||||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
|
||||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
|
||||||
if tokenInfo.RefreshToken != "" {
|
|
||||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
|
||||||
}
|
|
||||||
if tokenInfo.Scope != "" {
|
|
||||||
newCredentials["scope"] = tokenInfo.Scope
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if lockErr != nil {
|
|
||||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
|
||||||
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
|
||||||
|
|
||||||
// 检查 ctx 是否已取消
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return "", ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从数据库获取最新账户信息
|
|
||||||
if p.accountRepo != nil {
|
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
|
||||||
if err == nil && fresh != nil {
|
|
||||||
account = fresh
|
|
||||||
}
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
|
|
||||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
|
||||||
if p.oauthService == nil {
|
|
||||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
|
||||||
refreshFailed = true
|
|
||||||
} else {
|
|
||||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
|
||||||
refreshFailed = true
|
|
||||||
} else {
|
|
||||||
// 构建新 credentials,保留原有字段
|
|
||||||
newCredentials := make(map[string]any)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
|
||||||
newCredentials["token_type"] = tokenInfo.TokenType
|
|
||||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
|
||||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
|
||||||
if tokenInfo.RefreshToken != "" {
|
|
||||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
|
||||||
}
|
|
||||||
if tokenInfo.Scope != "" {
|
|
||||||
newCredentials["scope"] = tokenInfo.Scope
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
|
// Backward-compatible test path when refreshAPI is not injected.
|
||||||
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
|
if lockErr == nil && locked {
|
||||||
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
} else if lockErr != nil {
|
||||||
|
slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||||
|
} else {
|
||||||
time.Sleep(claudeLockWaitTime)
|
time.Sleep(claudeLockWaitTime)
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
@@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 3) Populate cache with TTL.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if isStale && latestAccount != nil {
|
if isStale && latestAccount != nil {
|
||||||
// 版本过时,使用 DB 中的最新 token
|
|
||||||
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
accessToken = latestAccount.GetCredential("access_token")
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
if strings.TrimSpace(accessToken) == "" {
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
return "", errors.New("access_token not found after version check")
|
return "", errors.New("access_token not found after version check")
|
||||||
}
|
}
|
||||||
// 不写入缓存,让下次请求重新处理
|
|
||||||
} else {
|
} else {
|
||||||
ttl := 30 * time.Minute
|
ttl := 30 * time.Minute
|
||||||
if refreshFailed {
|
if refreshFailed {
|
||||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
if p.refreshPolicy.FailureTTL > 0 {
|
||||||
ttl = time.Minute
|
ttl = p.refreshPolicy.FailureTTL
|
||||||
|
} else {
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||||
} else if expiresAt != nil {
|
} else if expiresAt != nil {
|
||||||
until := time.Until(*expiresAt)
|
until := time.Until(*expiresAt)
|
||||||
|
|||||||
@@ -15,10 +15,14 @@ const (
|
|||||||
geminiTokenCacheSkew = 5 * time.Minute
|
geminiTokenCacheSkew = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
|
||||||
type GeminiTokenProvider struct {
|
type GeminiTokenProvider struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache GeminiTokenCache
|
tokenCache GeminiTokenCache
|
||||||
geminiOAuthService *GeminiOAuthService
|
geminiOAuthService *GeminiOAuthService
|
||||||
|
refreshAPI *OAuthRefreshAPI
|
||||||
|
executor OAuthRefreshExecutor
|
||||||
|
refreshPolicy ProviderRefreshPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGeminiTokenProvider(
|
func NewGeminiTokenProvider(
|
||||||
@@ -30,9 +34,21 @@ func NewGeminiTokenProvider(
|
|||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
geminiOAuthService: geminiOAuthService,
|
geminiOAuthService: geminiOAuthService,
|
||||||
|
refreshPolicy: GeminiProviderRefreshPolicy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||||
|
func (p *GeminiTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||||
|
p.refreshAPI = api
|
||||||
|
p.executor = executor
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRefreshPolicy injects caller-side refresh policy.
|
||||||
|
func (p *GeminiTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||||
|
p.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
@@ -53,39 +69,31 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
// 2) Refresh if needed (pre-expiry skew).
|
// 2) Refresh if needed (pre-expiry skew).
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
|
||||||
if needsRefresh && p.tokenCache != nil {
|
|
||||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
|
||||||
if err == nil && locked {
|
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
|
||||||
|
|
||||||
// Re-check after lock (another worker may have refreshed).
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, geminiTokenRefreshSkew)
|
||||||
return token, nil
|
if err != nil {
|
||||||
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
|
return "", err
|
||||||
}
|
}
|
||||||
|
} else if result.LockHeld {
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||||
if err == nil && fresh != nil {
|
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||||
account = fresh
|
return token, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
slog.Debug("gemini_token_lock_held_use_old", "account_id", account.ID)
|
||||||
|
} else {
|
||||||
|
account = result.Account
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew {
|
}
|
||||||
if p.geminiOAuthService == nil {
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
return "", errors.New("gemini oauth service not configured")
|
// Backward-compatible test path when refreshAPI is not injected.
|
||||||
}
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account)
|
if lockErr == nil && locked {
|
||||||
if err != nil {
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
return "", err
|
} else if lockErr != nil {
|
||||||
}
|
slog.Warn("gemini_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||||
newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
_ = p.accountRepo.Update(ctx, account)
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,15 +103,14 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
// project_id is optional now:
|
// project_id is optional now:
|
||||||
// - If present: will use Code Assist API (requires project_id)
|
// - If present: use Code Assist API (requires project_id)
|
||||||
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
|
// - If absent: use AI Studio API with OAuth token.
|
||||||
// Auto-detect project_id only if explicitly enabled via a credential flag
|
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
|
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
|
||||||
|
|
||||||
if projectID == "" && autoDetectProjectID {
|
if projectID == "" && autoDetectProjectID {
|
||||||
if p.geminiOAuthService == nil {
|
if p.geminiOAuthService == nil {
|
||||||
return accessToken, nil // Fallback to AI Studio API mode
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
@@ -132,17 +139,15 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 3) Populate cache with TTL.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if isStale && latestAccount != nil {
|
if isStale && latestAccount != nil {
|
||||||
// 版本过时,使用 DB 中的最新 token
|
|
||||||
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
accessToken = latestAccount.GetCredential("access_token")
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
if strings.TrimSpace(accessToken) == "" {
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
return "", errors.New("access_token not found after version check")
|
return "", errors.New("access_token not found after version check")
|
||||||
}
|
}
|
||||||
// 不写入缓存,让下次请求重新处理
|
|
||||||
} else {
|
} else {
|
||||||
ttl := 30 * time.Minute
|
ttl := 30 * time.Minute
|
||||||
if expiresAt != nil {
|
if expiresAt != nil {
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke
|
|||||||
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
|
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheKey 返回用于分布式锁的缓存键
|
||||||
|
func (r *GeminiTokenRefresher) CacheKey(account *Account) string {
|
||||||
|
return GeminiTokenCacheKey(account)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
|
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
|
||||||
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
|
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
|
||||||
}
|
}
|
||||||
@@ -35,11 +40,7 @@ func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
|||||||
}
|
}
|
||||||
|
|
||||||
newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
for k, v := range account.Credentials {
|
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return newCredentials, nil
|
return newCredentials, nil
|
||||||
}
|
}
|
||||||
|
|||||||
159
backend/internal/service/oauth_refresh_api.go
Normal file
159
backend/internal/service/oauth_refresh_api.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
|
||||||
|
// TokenRefresher 接口的超集:增加了 CacheKey 方法用于分布式锁
|
||||||
|
type OAuthRefreshExecutor interface {
|
||||||
|
TokenRefresher
|
||||||
|
|
||||||
|
// CacheKey 返回用于分布式锁的缓存键(与 TokenProvider 使用的一致)
|
||||||
|
CacheKey(account *Account) string
|
||||||
|
}
|
||||||
|
|
||||||
|
const refreshLockTTL = 30 * time.Second
|
||||||
|
|
||||||
|
// OAuthRefreshResult 统一刷新结果
|
||||||
|
type OAuthRefreshResult struct {
|
||||||
|
Refreshed bool // 实际执行了刷新
|
||||||
|
NewCredentials map[string]any // 刷新后的 credentials(nil 表示未刷新)
|
||||||
|
Account *Account // 从 DB 重新读取的最新 account
|
||||||
|
LockHeld bool // 锁被其他 worker 持有(未执行刷新)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
||||||
|
// 封装分布式锁、DB 重读、已刷新检查等通用逻辑
|
||||||
|
type OAuthRefreshAPI struct {
|
||||||
|
accountRepo AccountRepository
|
||||||
|
tokenCache GeminiTokenCache // 可选,nil = 无锁
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOAuthRefreshAPI 创建统一刷新 API
|
||||||
|
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
|
||||||
|
return &OAuthRefreshAPI{
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
tokenCache: tokenCache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
||||||
|
//
|
||||||
|
// 流程:
|
||||||
|
// 1. 获取分布式锁
|
||||||
|
// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||||||
|
// 3. 二次检查是否仍需刷新
|
||||||
|
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||||||
|
// 5. 设置 _token_version + 更新 DB
|
||||||
|
// 6. 释放锁
|
||||||
|
func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||||
|
ctx context.Context,
|
||||||
|
account *Account,
|
||||||
|
executor OAuthRefreshExecutor,
|
||||||
|
refreshWindow time.Duration,
|
||||||
|
) (*OAuthRefreshResult, error) {
|
||||||
|
cacheKey := executor.CacheKey(account)
|
||||||
|
|
||||||
|
// 1. 获取分布式锁
|
||||||
|
lockAcquired := false
|
||||||
|
if api.tokenCache != nil {
|
||||||
|
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
|
||||||
|
if lockErr != nil {
|
||||||
|
// Redis 错误,降级为无锁刷新
|
||||||
|
slog.Warn("oauth_refresh_lock_failed_degraded",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"cache_key", cacheKey,
|
||||||
|
"error", lockErr,
|
||||||
|
)
|
||||||
|
} else if !acquired {
|
||||||
|
// 锁被其他 worker 持有
|
||||||
|
return &OAuthRefreshResult{LockHeld: true}, nil
|
||||||
|
} else {
|
||||||
|
lockAcquired = true
|
||||||
|
defer func() { _ = api.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 从 DB 重读最新 account(锁保护下,确保使用最新的 refresh_token)
|
||||||
|
freshAccount, err := api.accountRepo.GetByID(ctx, account.ID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("oauth_refresh_db_reread_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
// 降级使用传入的 account
|
||||||
|
freshAccount = account
|
||||||
|
} else if freshAccount == nil {
|
||||||
|
freshAccount = account
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 二次检查是否仍需刷新(另一条路径可能已刷新)
|
||||||
|
if !executor.NeedsRefresh(freshAccount, refreshWindow) {
|
||||||
|
return &OAuthRefreshResult{
|
||||||
|
Account: freshAccount,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 执行平台特定刷新逻辑
|
||||||
|
newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
|
||||||
|
if refreshErr != nil {
|
||||||
|
return nil, refreshErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 设置版本号 + 更新 DB
|
||||||
|
if newCredentials != nil {
|
||||||
|
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||||
|
freshAccount.Credentials = newCredentials
|
||||||
|
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
|
||||||
|
slog.Error("oauth_refresh_update_failed",
|
||||||
|
"account_id", freshAccount.ID,
|
||||||
|
"error", updateErr,
|
||||||
|
)
|
||||||
|
return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = lockAcquired // suppress unused warning when tokenCache is nil
|
||||||
|
|
||||||
|
return &OAuthRefreshResult{
|
||||||
|
Refreshed: true,
|
||||||
|
NewCredentials: newCredentials,
|
||||||
|
Account: freshAccount,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中
|
||||||
|
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
|
||||||
|
if newCreds == nil {
|
||||||
|
newCreds = make(map[string]any)
|
||||||
|
}
|
||||||
|
for k, v := range oldCreds {
|
||||||
|
if _, exists := newCreds[k]; !exists {
|
||||||
|
newCreds[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return newCreds
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudeAccountCredentials 为 Claude 平台构建 OAuth credentials map
|
||||||
|
// 消除 Claude 平台没有 BuildAccountCredentials 方法的问题
|
||||||
|
func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any {
|
||||||
|
creds := map[string]any{
|
||||||
|
"access_token": tokenInfo.AccessToken,
|
||||||
|
"token_type": tokenInfo.TokenType,
|
||||||
|
"expires_in": strconv.FormatInt(tokenInfo.ExpiresIn, 10),
|
||||||
|
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
|
||||||
|
}
|
||||||
|
if tokenInfo.RefreshToken != "" {
|
||||||
|
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||||
|
}
|
||||||
|
if tokenInfo.Scope != "" {
|
||||||
|
creds["scope"] = tokenInfo.Scope
|
||||||
|
}
|
||||||
|
return creds
|
||||||
|
}
|
||||||
395
backend/internal/service/oauth_refresh_api_test.go
Normal file
395
backend/internal/service/oauth_refresh_api_test.go
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------- mock helpers ----------
|
||||||
|
|
||||||
|
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
|
||||||
|
type refreshAPIAccountRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
account *Account // returned by GetByID
|
||||||
|
getByIDErr error
|
||||||
|
updateErr error
|
||||||
|
updateCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||||
|
if r.getByIDErr != nil {
|
||||||
|
return nil, r.getByIDErr
|
||||||
|
}
|
||||||
|
return r.account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
|
||||||
|
r.updateCalls++
|
||||||
|
return r.updateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
|
||||||
|
type refreshAPIExecutorStub struct {
|
||||||
|
needsRefresh bool
|
||||||
|
credentials map[string]any
|
||||||
|
err error
|
||||||
|
refreshCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *refreshAPIExecutorStub) CanRefresh(_ *Account) bool { return true }
|
||||||
|
|
||||||
|
func (e *refreshAPIExecutorStub) NeedsRefresh(_ *Account, _ time.Duration) bool {
|
||||||
|
return e.needsRefresh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *refreshAPIExecutorStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
|
||||||
|
e.refreshCalls++
|
||||||
|
if e.err != nil {
|
||||||
|
return nil, e.err
|
||||||
|
}
|
||||||
|
return e.credentials, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *refreshAPIExecutorStub) CacheKey(account *Account) string {
|
||||||
|
return "test:api:" + account.Platform
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshAPICacheStub implements GeminiTokenCache for OAuthRefreshAPI tests.
|
||||||
|
type refreshAPICacheStub struct {
|
||||||
|
lockResult bool
|
||||||
|
lockErr error
|
||||||
|
releaseCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *refreshAPICacheStub) GetAccessToken(context.Context, string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *refreshAPICacheStub) SetAccessToken(context.Context, string, string, time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *refreshAPICacheStub) DeleteAccessToken(context.Context, string) error { return nil }
|
||||||
|
|
||||||
|
func (c *refreshAPICacheStub) AcquireRefreshLock(context.Context, string, time.Duration) (bool, error) {
|
||||||
|
return c.lockResult, c.lockErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *refreshAPICacheStub) ReleaseRefreshLock(context.Context, string) error {
|
||||||
|
c.releaseCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== RefreshIfNeeded tests ==========
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_Success(t *testing.T) {
|
||||||
|
account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
|
||||||
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "new-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.Refreshed)
|
||||||
|
require.NotNil(t, result.NewCredentials)
|
||||||
|
require.Equal(t, "new-token", result.NewCredentials["access_token"])
|
||||||
|
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
|
||||||
|
require.Equal(t, 1, repo.updateCalls) // DB updated
|
||||||
|
require.Equal(t, 1, cache.releaseCalls) // lock released
|
||||||
|
require.Equal(t, 1, executor.refreshCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_LockHeld(t *testing.T) {
|
||||||
|
account := &Account{ID: 2, Platform: PlatformAnthropic}
|
||||||
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: false} // lock not acquired
|
||||||
|
executor := &refreshAPIExecutorStub{needsRefresh: true}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.LockHeld)
|
||||||
|
require.False(t, result.Refreshed)
|
||||||
|
require.Equal(t, 0, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, executor.refreshCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_LockErrorDegrades(t *testing.T) {
|
||||||
|
account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||||
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
cache := &refreshAPICacheStub{lockErr: errors.New("redis down")} // lock error
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "degraded-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.Refreshed) // still refreshed (degraded mode)
|
||||||
|
require.Equal(t, 1, repo.updateCalls) // DB updated
|
||||||
|
require.Equal(t, 0, cache.releaseCalls) // no lock to release
|
||||||
|
require.Equal(t, 1, executor.refreshCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_NoCacheNoLock(t *testing.T) {
|
||||||
|
account := &Account{ID: 4, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||||
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "no-cache-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, nil) // no cache = no lock
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.Refreshed)
|
||||||
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_AlreadyRefreshed(t *testing.T) {
|
||||||
|
account := &Account{ID: 5, Platform: PlatformAnthropic}
|
||||||
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{needsRefresh: false} // already refreshed
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, result.Refreshed)
|
||||||
|
require.False(t, result.LockHeld)
|
||||||
|
require.NotNil(t, result.Account) // returns fresh account
|
||||||
|
require.Equal(t, 0, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, executor.refreshCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_RefreshError(t *testing.T) {
|
||||||
|
account := &Account{ID: 6, Platform: PlatformAnthropic}
|
||||||
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
err: errors.New("invalid_grant: token revoked"),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Contains(t, err.Error(), "invalid_grant")
|
||||||
|
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
|
||||||
|
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_DBUpdateError(t *testing.T) {
|
||||||
|
account := &Account{ID: 7, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||||
|
repo := &refreshAPIAccountRepo{
|
||||||
|
account: account,
|
||||||
|
updateErr: errors.New("db connection lost"),
|
||||||
|
}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Contains(t, err.Error(), "DB update failed")
|
||||||
|
require.Equal(t, 1, repo.updateCalls) // attempted
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_DBRereadFails(t *testing.T) {
|
||||||
|
account := &Account{ID: 8, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
|
||||||
|
repo := &refreshAPIAccountRepo{
|
||||||
|
account: nil, // GetByID returns nil
|
||||||
|
getByIDErr: errors.New("db timeout"),
|
||||||
|
}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "fallback-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.Refreshed)
|
||||||
|
require.Equal(t, 1, executor.refreshCalls) // still refreshes using passed-in account
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_NilCredentials(t *testing.T) {
|
||||||
|
account := &Account{ID: 9, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||||
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: nil, // Refresh returns nil credentials
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.Refreshed)
|
||||||
|
require.Nil(t, result.NewCredentials)
|
||||||
|
require.Equal(t, 0, repo.updateCalls) // no DB update when credentials are nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== MergeCredentials tests ==========
|
||||||
|
|
||||||
|
func TestMergeCredentials_Basic(t *testing.T) {
|
||||||
|
old := map[string]any{"a": "1", "b": "2", "c": "3"}
|
||||||
|
new := map[string]any{"a": "new", "d": "4"}
|
||||||
|
|
||||||
|
result := MergeCredentials(old, new)
|
||||||
|
|
||||||
|
require.Equal(t, "new", result["a"]) // new value preserved
|
||||||
|
require.Equal(t, "2", result["b"]) // old value kept
|
||||||
|
require.Equal(t, "3", result["c"]) // old value kept
|
||||||
|
require.Equal(t, "4", result["d"]) // new value preserved
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCredentials_NilNew(t *testing.T) {
|
||||||
|
old := map[string]any{"a": "1"}
|
||||||
|
|
||||||
|
result := MergeCredentials(old, nil)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "1", result["a"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCredentials_NilOld(t *testing.T) {
|
||||||
|
new := map[string]any{"a": "1"}
|
||||||
|
|
||||||
|
result := MergeCredentials(nil, new)
|
||||||
|
|
||||||
|
require.Equal(t, "1", result["a"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCredentials_BothNil(t *testing.T) {
|
||||||
|
result := MergeCredentials(nil, nil)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Empty(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCredentials_NewOverridesOld(t *testing.T) {
|
||||||
|
old := map[string]any{"access_token": "old-token", "refresh_token": "old-refresh"}
|
||||||
|
new := map[string]any{"access_token": "new-token"}
|
||||||
|
|
||||||
|
result := MergeCredentials(old, new)
|
||||||
|
|
||||||
|
require.Equal(t, "new-token", result["access_token"]) // overridden
|
||||||
|
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== BuildClaudeAccountCredentials tests ==========
|
||||||
|
|
||||||
|
func TestBuildClaudeAccountCredentials_Full(t *testing.T) {
|
||||||
|
tokenInfo := &TokenInfo{
|
||||||
|
AccessToken: "at-123",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
ExpiresAt: 1700000000,
|
||||||
|
RefreshToken: "rt-456",
|
||||||
|
Scope: "openid",
|
||||||
|
}
|
||||||
|
|
||||||
|
creds := BuildClaudeAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
|
require.Equal(t, "at-123", creds["access_token"])
|
||||||
|
require.Equal(t, "Bearer", creds["token_type"])
|
||||||
|
require.Equal(t, "3600", creds["expires_in"])
|
||||||
|
require.Equal(t, "1700000000", creds["expires_at"])
|
||||||
|
require.Equal(t, "rt-456", creds["refresh_token"])
|
||||||
|
require.Equal(t, "openid", creds["scope"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) {
|
||||||
|
tokenInfo := &TokenInfo{
|
||||||
|
AccessToken: "at-789",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 7200,
|
||||||
|
ExpiresAt: 1700003600,
|
||||||
|
}
|
||||||
|
|
||||||
|
creds := BuildClaudeAccountCredentials(tokenInfo)
|
||||||
|
|
||||||
|
require.Equal(t, "at-789", creds["access_token"])
|
||||||
|
require.Equal(t, "Bearer", creds["token_type"])
|
||||||
|
require.Equal(t, "7200", creds["expires_in"])
|
||||||
|
require.Equal(t, "1700003600", creds["expires_at"])
|
||||||
|
_, hasRefresh := creds["refresh_token"]
|
||||||
|
_, hasScope := creds["scope"]
|
||||||
|
require.False(t, hasRefresh, "refresh_token should not be set when empty")
|
||||||
|
require.False(t, hasScope, "scope should not be set when empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== BackgroundRefreshPolicy tests ==========
|
||||||
|
|
||||||
|
func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) {
|
||||||
|
p := DefaultBackgroundRefreshPolicy()
|
||||||
|
|
||||||
|
require.ErrorIs(t, p.handleLockHeld(), errRefreshSkipped)
|
||||||
|
require.ErrorIs(t, p.handleAlreadyRefreshed(), errRefreshSkipped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackgroundRefreshPolicy_SuccessOverride(t *testing.T) {
|
||||||
|
p := BackgroundRefreshPolicy{
|
||||||
|
OnLockHeld: BackgroundSkipAsSuccess,
|
||||||
|
OnAlreadyRefresh: BackgroundSkipAsSuccess,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, p.handleLockHeld())
|
||||||
|
require.NoError(t, p.handleAlreadyRefreshed())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== ProviderRefreshPolicy tests ==========
|
||||||
|
|
||||||
|
func TestClaudeProviderRefreshPolicy(t *testing.T) {
|
||||||
|
p := ClaudeProviderRefreshPolicy()
|
||||||
|
require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError)
|
||||||
|
require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld)
|
||||||
|
require.Equal(t, time.Minute, p.FailureTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIProviderRefreshPolicy(t *testing.T) {
|
||||||
|
p := OpenAIProviderRefreshPolicy()
|
||||||
|
require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError)
|
||||||
|
require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld)
|
||||||
|
require.Equal(t, time.Minute, p.FailureTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiProviderRefreshPolicy(t *testing.T) {
|
||||||
|
p := GeminiProviderRefreshPolicy()
|
||||||
|
require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError)
|
||||||
|
require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld)
|
||||||
|
require.Equal(t, time.Duration(0), p.FailureTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityProviderRefreshPolicy(t *testing.T) {
|
||||||
|
p := AntigravityProviderRefreshPolicy()
|
||||||
|
require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError)
|
||||||
|
require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld)
|
||||||
|
require.Equal(t, time.Duration(0), p.FailureTTL)
|
||||||
|
}
|
||||||
@@ -20,7 +20,7 @@ const (
|
|||||||
openAILockWarnThresholdMs = 250
|
openAILockWarnThresholdMs = 250
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。
|
// OpenAITokenRuntimeMetrics is a snapshot of refresh and lock contention metrics.
|
||||||
type OpenAITokenRuntimeMetrics struct {
|
type OpenAITokenRuntimeMetrics struct {
|
||||||
RefreshRequests int64
|
RefreshRequests int64
|
||||||
RefreshSuccess int64
|
RefreshSuccess int64
|
||||||
@@ -72,15 +72,18 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
|
|||||||
m.lastObservedUnixMs.Store(time.Now().UnixMilli())
|
m.lastObservedUnixMs.Store(time.Now().UnixMilli())
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
// OpenAITokenCache token cache interface.
|
||||||
type OpenAITokenCache = GeminiTokenCache
|
type OpenAITokenCache = GeminiTokenCache
|
||||||
|
|
||||||
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
|
// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
|
||||||
type OpenAITokenProvider struct {
|
type OpenAITokenProvider struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache OpenAITokenCache
|
tokenCache OpenAITokenCache
|
||||||
openAIOAuthService *OpenAIOAuthService
|
openAIOAuthService *OpenAIOAuthService
|
||||||
metrics *openAITokenRuntimeMetricsStore
|
metrics *openAITokenRuntimeMetricsStore
|
||||||
|
refreshAPI *OAuthRefreshAPI
|
||||||
|
executor OAuthRefreshExecutor
|
||||||
|
refreshPolicy ProviderRefreshPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpenAITokenProvider(
|
func NewOpenAITokenProvider(
|
||||||
@@ -93,9 +96,21 @@ func NewOpenAITokenProvider(
|
|||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
openAIOAuthService: openAIOAuthService,
|
openAIOAuthService: openAIOAuthService,
|
||||||
metrics: &openAITokenRuntimeMetricsStore{},
|
metrics: &openAITokenRuntimeMetricsStore{},
|
||||||
|
refreshPolicy: OpenAIProviderRefreshPolicy(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||||
|
func (p *OpenAITokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||||
|
p.refreshAPI = api
|
||||||
|
p.executor = executor
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRefreshPolicy injects caller-side refresh policy.
|
||||||
|
func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||||
|
p.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return OpenAITokenRuntimeMetrics{}
|
return OpenAITokenRuntimeMetrics{}
|
||||||
@@ -110,7 +125,7 @@ func (p *OpenAITokenProvider) ensureMetrics() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取有效的 access_token
|
// GetAccessToken returns a valid access_token.
|
||||||
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
p.ensureMetrics()
|
p.ensureMetrics()
|
||||||
if account == nil {
|
if account == nil {
|
||||||
@@ -122,7 +137,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
cacheKey := OpenAITokenCacheKey(account)
|
cacheKey := OpenAITokenCacheKey(account)
|
||||||
|
|
||||||
// 1. 先尝试缓存
|
// 1) Try cache first.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||||
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
|
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
|
||||||
@@ -134,114 +149,62 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
|
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
|
||||||
|
|
||||||
// 2. 如果即将过期则刷新
|
// 2) Refresh if needed (pre-expiry skew).
|
||||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||||
refreshFailed := false
|
refreshFailed := false
|
||||||
if needsRefresh && p.tokenCache != nil {
|
|
||||||
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||||
|
p.metrics.refreshRequests.Add(1)
|
||||||
|
p.metrics.touchNow()
|
||||||
|
|
||||||
|
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
|
||||||
|
if account.Platform == PlatformSora {
|
||||||
|
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
||||||
|
refreshFailed = true
|
||||||
|
} else {
|
||||||
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
|
||||||
|
if err != nil {
|
||||||
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||||
|
p.metrics.refreshFailure.Add(1)
|
||||||
|
refreshFailed = true
|
||||||
|
} else if result.LockHeld {
|
||||||
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
|
||||||
|
p.metrics.lockContention.Add(1)
|
||||||
|
p.metrics.touchNow()
|
||||||
|
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||||
|
if waitErr != nil {
|
||||||
|
return "", waitErr
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(token) != "" {
|
||||||
|
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if result.Refreshed {
|
||||||
|
p.metrics.refreshSuccess.Add(1)
|
||||||
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
} else {
|
||||||
|
account = result.Account
|
||||||
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if needsRefresh && p.tokenCache != nil {
|
||||||
|
// Backward-compatible test path when refreshAPI is not injected.
|
||||||
p.metrics.refreshRequests.Add(1)
|
p.metrics.refreshRequests.Add(1)
|
||||||
p.metrics.touchNow()
|
p.metrics.touchNow()
|
||||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
if lockErr == nil && locked {
|
if lockErr == nil && locked {
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
|
|
||||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从数据库获取最新账户信息
|
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
|
||||||
if err == nil && fresh != nil {
|
|
||||||
account = fresh
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
|
||||||
if account.Platform == PlatformSora {
|
|
||||||
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
|
||||||
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
|
||||||
refreshFailed = true
|
|
||||||
} else if p.openAIOAuthService == nil {
|
|
||||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
|
||||||
p.metrics.refreshFailure.Add(1)
|
|
||||||
refreshFailed = true // 无法刷新,标记失败
|
|
||||||
} else {
|
|
||||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
|
||||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
|
||||||
p.metrics.refreshFailure.Add(1)
|
|
||||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
|
||||||
} else {
|
|
||||||
p.metrics.refreshSuccess.Add(1)
|
|
||||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if lockErr != nil {
|
} else if lockErr != nil {
|
||||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
|
||||||
p.metrics.lockAcquireFailure.Add(1)
|
p.metrics.lockAcquireFailure.Add(1)
|
||||||
p.metrics.touchNow()
|
p.metrics.touchNow()
|
||||||
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
slog.Warn("openai_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||||
|
|
||||||
// 检查 ctx 是否已取消
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return "", ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从数据库获取最新账户信息
|
|
||||||
if p.accountRepo != nil {
|
|
||||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
|
||||||
if err == nil && fresh != nil {
|
|
||||||
account = fresh
|
|
||||||
}
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
|
|
||||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
|
||||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
|
||||||
if account.Platform == PlatformSora {
|
|
||||||
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
|
|
||||||
// Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
|
|
||||||
refreshFailed = true
|
|
||||||
} else if p.openAIOAuthService == nil {
|
|
||||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
|
||||||
p.metrics.refreshFailure.Add(1)
|
|
||||||
refreshFailed = true
|
|
||||||
} else {
|
|
||||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
|
||||||
p.metrics.refreshFailure.Add(1)
|
|
||||||
refreshFailed = true
|
|
||||||
} else {
|
|
||||||
p.metrics.refreshSuccess.Add(1)
|
|
||||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
|
||||||
}
|
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。
|
|
||||||
p.metrics.lockContention.Add(1)
|
p.metrics.lockContention.Add(1)
|
||||||
p.metrics.touchNow()
|
p.metrics.touchNow()
|
||||||
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||||
@@ -260,22 +223,23 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 3) Populate cache with TTL.
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if isStale && latestAccount != nil {
|
if isStale && latestAccount != nil {
|
||||||
// 版本过时,使用 DB 中的最新 token
|
|
||||||
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
|
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
accessToken = latestAccount.GetOpenAIAccessToken()
|
accessToken = latestAccount.GetOpenAIAccessToken()
|
||||||
if strings.TrimSpace(accessToken) == "" {
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
return "", errors.New("access_token not found after version check")
|
return "", errors.New("access_token not found after version check")
|
||||||
}
|
}
|
||||||
// 不写入缓存,让下次请求重新处理
|
|
||||||
} else {
|
} else {
|
||||||
ttl := 30 * time.Minute
|
ttl := 30 * time.Minute
|
||||||
if refreshFailed {
|
if refreshFailed {
|
||||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
if p.refreshPolicy.FailureTTL > 0 {
|
||||||
ttl = time.Minute
|
ttl = p.refreshPolicy.FailureTTL
|
||||||
|
} else {
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||||
} else if expiresAt != nil {
|
} else if expiresAt != nil {
|
||||||
until := time.Until(*expiresAt)
|
until := time.Until(*expiresAt)
|
||||||
|
|||||||
99
backend/internal/service/refresh_policy.go
Normal file
99
backend/internal/service/refresh_policy.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// ProviderRefreshErrorAction 定义 provider 在刷新失败时的处理动作。
|
||||||
|
type ProviderRefreshErrorAction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProviderRefreshErrorReturn 失败即返回错误(不降级旧 token)。
|
||||||
|
ProviderRefreshErrorReturn ProviderRefreshErrorAction = iota
|
||||||
|
// ProviderRefreshErrorUseExistingToken 失败后继续使用现有 token。
|
||||||
|
ProviderRefreshErrorUseExistingToken
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderLockHeldAction 定义 provider 在刷新锁被占用时的处理动作。
|
||||||
|
type ProviderLockHeldAction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProviderLockHeldUseExistingToken 直接使用现有 token。
|
||||||
|
ProviderLockHeldUseExistingToken ProviderLockHeldAction = iota
|
||||||
|
// ProviderLockHeldWaitForCache 等待后重试缓存读取。
|
||||||
|
ProviderLockHeldWaitForCache
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderRefreshPolicy 描述 provider 的平台差异策略。
|
||||||
|
type ProviderRefreshPolicy struct {
|
||||||
|
OnRefreshError ProviderRefreshErrorAction
|
||||||
|
OnLockHeld ProviderLockHeldAction
|
||||||
|
FailureTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClaudeProviderRefreshPolicy() ProviderRefreshPolicy {
|
||||||
|
return ProviderRefreshPolicy{
|
||||||
|
OnRefreshError: ProviderRefreshErrorUseExistingToken,
|
||||||
|
OnLockHeld: ProviderLockHeldWaitForCache,
|
||||||
|
FailureTTL: time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenAIProviderRefreshPolicy() ProviderRefreshPolicy {
|
||||||
|
return ProviderRefreshPolicy{
|
||||||
|
OnRefreshError: ProviderRefreshErrorUseExistingToken,
|
||||||
|
OnLockHeld: ProviderLockHeldWaitForCache,
|
||||||
|
FailureTTL: time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeminiProviderRefreshPolicy() ProviderRefreshPolicy {
|
||||||
|
return ProviderRefreshPolicy{
|
||||||
|
OnRefreshError: ProviderRefreshErrorReturn,
|
||||||
|
OnLockHeld: ProviderLockHeldUseExistingToken,
|
||||||
|
FailureTTL: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AntigravityProviderRefreshPolicy() ProviderRefreshPolicy {
|
||||||
|
return ProviderRefreshPolicy{
|
||||||
|
OnRefreshError: ProviderRefreshErrorReturn,
|
||||||
|
OnLockHeld: ProviderLockHeldUseExistingToken,
|
||||||
|
FailureTTL: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackgroundSkipAction 定义后台刷新服务在“未实际刷新”场景的计数方式。
|
||||||
|
type BackgroundSkipAction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// BackgroundSkipAsSkipped 计入 skipped(保持当前默认行为)。
|
||||||
|
BackgroundSkipAsSkipped BackgroundSkipAction = iota
|
||||||
|
// BackgroundSkipAsSuccess 计入 success(仅用于兼容旧统计口径时可选)。
|
||||||
|
BackgroundSkipAsSuccess
|
||||||
|
)
|
||||||
|
|
||||||
|
// BackgroundRefreshPolicy 描述后台刷新服务的调用侧策略。
|
||||||
|
type BackgroundRefreshPolicy struct {
|
||||||
|
OnLockHeld BackgroundSkipAction
|
||||||
|
OnAlreadyRefresh BackgroundSkipAction
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultBackgroundRefreshPolicy() BackgroundRefreshPolicy {
|
||||||
|
return BackgroundRefreshPolicy{
|
||||||
|
OnLockHeld: BackgroundSkipAsSkipped,
|
||||||
|
OnAlreadyRefresh: BackgroundSkipAsSkipped,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p BackgroundRefreshPolicy) handleLockHeld() error {
|
||||||
|
if p.OnLockHeld == BackgroundSkipAsSuccess {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errRefreshSkipped
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p BackgroundRefreshPolicy) handleAlreadyRefreshed() error {
|
||||||
|
if p.OnAlreadyRefresh == BackgroundSkipAsSuccess {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errRefreshSkipped
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -16,10 +17,13 @@ import (
|
|||||||
type TokenRefreshService struct {
|
type TokenRefreshService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
refreshers []TokenRefresher
|
refreshers []TokenRefresher
|
||||||
|
executors []OAuthRefreshExecutor // 与 refreshers 一一对应的 executor(带 CacheKey)
|
||||||
|
refreshPolicy BackgroundRefreshPolicy
|
||||||
cfg *config.TokenRefreshConfig
|
cfg *config.TokenRefreshConfig
|
||||||
cacheInvalidator TokenCacheInvalidator
|
cacheInvalidator TokenCacheInvalidator
|
||||||
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
||||||
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
||||||
|
refreshAPI *OAuthRefreshAPI // 统一刷新 API
|
||||||
|
|
||||||
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
||||||
privacyClientFactory PrivacyClientFactory
|
privacyClientFactory PrivacyClientFactory
|
||||||
@@ -43,6 +47,7 @@ func NewTokenRefreshService(
|
|||||||
) *TokenRefreshService {
|
) *TokenRefreshService {
|
||||||
s := &TokenRefreshService{
|
s := &TokenRefreshService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
|
refreshPolicy: DefaultBackgroundRefreshPolicy(),
|
||||||
cfg: &cfg.TokenRefresh,
|
cfg: &cfg.TokenRefresh,
|
||||||
cacheInvalidator: cacheInvalidator,
|
cacheInvalidator: cacheInvalidator,
|
||||||
schedulerCache: schedulerCache,
|
schedulerCache: schedulerCache,
|
||||||
@@ -53,12 +58,24 @@ func NewTokenRefreshService(
|
|||||||
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||||
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
|
||||||
|
|
||||||
// 注册平台特定的刷新器
|
claudeRefresher := NewClaudeTokenRefresher(oauthService)
|
||||||
|
geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService)
|
||||||
|
agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService)
|
||||||
|
|
||||||
|
// 注册平台特定的刷新器(TokenRefresher 接口)
|
||||||
s.refreshers = []TokenRefresher{
|
s.refreshers = []TokenRefresher{
|
||||||
NewClaudeTokenRefresher(oauthService),
|
claudeRefresher,
|
||||||
openAIRefresher,
|
openAIRefresher,
|
||||||
NewGeminiTokenRefresher(geminiOAuthService),
|
geminiRefresher,
|
||||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
agRefresher,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法)
|
||||||
|
s.executors = []OAuthRefreshExecutor{
|
||||||
|
claudeRefresher,
|
||||||
|
openAIRefresher,
|
||||||
|
geminiRefresher,
|
||||||
|
agRefresher,
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
@@ -82,6 +99,16 @@ func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxy
|
|||||||
s.proxyRepo = proxyRepo
|
s.proxyRepo = proxyRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRefreshAPI 注入统一的 OAuth 刷新 API
|
||||||
|
func (s *TokenRefreshService) SetRefreshAPI(api *OAuthRefreshAPI) {
|
||||||
|
s.refreshAPI = api
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRefreshPolicy 注入后台刷新调用侧策略(用于显式化平台/场景差异行为)。
|
||||||
|
func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
|
||||||
|
s.refreshPolicy = policy
|
||||||
|
}
|
||||||
|
|
||||||
// Start 启动后台刷新服务
|
// Start 启动后台刷新服务
|
||||||
func (s *TokenRefreshService) Start() {
|
func (s *TokenRefreshService) Start() {
|
||||||
if !s.cfg.Enabled {
|
if !s.cfg.Enabled {
|
||||||
@@ -148,13 +175,13 @@ func (s *TokenRefreshService) processRefresh() {
|
|||||||
totalAccounts := len(accounts)
|
totalAccounts := len(accounts)
|
||||||
oauthAccounts := 0 // 可刷新的OAuth账号数
|
oauthAccounts := 0 // 可刷新的OAuth账号数
|
||||||
needsRefresh := 0 // 需要刷新的账号数
|
needsRefresh := 0 // 需要刷新的账号数
|
||||||
refreshed, failed := 0, 0
|
refreshed, failed, skipped := 0, 0, 0
|
||||||
|
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
account := &accounts[i]
|
account := &accounts[i]
|
||||||
|
|
||||||
// 遍历所有刷新器,找到能处理此账号的
|
// 遍历所有刷新器,找到能处理此账号的
|
||||||
for _, refresher := range s.refreshers {
|
for idx, refresher := range s.refreshers {
|
||||||
if !refresher.CanRefresh(account) {
|
if !refresher.CanRefresh(account) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -168,14 +195,24 @@ func (s *TokenRefreshService) processRefresh() {
|
|||||||
|
|
||||||
needsRefresh++
|
needsRefresh++
|
||||||
|
|
||||||
|
// 获取对应的 executor
|
||||||
|
var executor OAuthRefreshExecutor
|
||||||
|
if idx < len(s.executors) {
|
||||||
|
executor = s.executors[idx]
|
||||||
|
}
|
||||||
|
|
||||||
// 执行刷新
|
// 执行刷新
|
||||||
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
|
if err := s.refreshWithRetry(ctx, account, refresher, executor, refreshWindow); err != nil {
|
||||||
slog.Warn("token_refresh.account_refresh_failed",
|
if errors.Is(err, errRefreshSkipped) {
|
||||||
"account_id", account.ID,
|
skipped++
|
||||||
"account_name", account.Name,
|
} else {
|
||||||
"error", err,
|
slog.Warn("token_refresh.account_refresh_failed",
|
||||||
)
|
"account_id", account.ID,
|
||||||
failed++
|
"account_name", account.Name,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
failed++
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
slog.Info("token_refresh.account_refreshed",
|
slog.Info("token_refresh.account_refreshed",
|
||||||
"account_id", account.ID,
|
"account_id", account.ID,
|
||||||
@@ -193,13 +230,14 @@ func (s *TokenRefreshService) processRefresh() {
|
|||||||
if needsRefresh == 0 && failed == 0 {
|
if needsRefresh == 0 && failed == 0 {
|
||||||
slog.Debug("token_refresh.cycle_completed",
|
slog.Debug("token_refresh.cycle_completed",
|
||||||
"total", totalAccounts, "oauth", oauthAccounts,
|
"total", totalAccounts, "oauth", oauthAccounts,
|
||||||
"needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed)
|
"needs_refresh", needsRefresh, "refreshed", refreshed, "skipped", skipped, "failed", failed)
|
||||||
} else {
|
} else {
|
||||||
slog.Info("token_refresh.cycle_completed",
|
slog.Info("token_refresh.cycle_completed",
|
||||||
"total", totalAccounts,
|
"total", totalAccounts,
|
||||||
"oauth", oauthAccounts,
|
"oauth", oauthAccounts,
|
||||||
"needs_refresh", needsRefresh,
|
"needs_refresh", needsRefresh,
|
||||||
"refreshed", refreshed,
|
"refreshed", refreshed,
|
||||||
|
"skipped", skipped,
|
||||||
"failed", failed,
|
"failed", failed,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -212,83 +250,42 @@ func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// refreshWithRetry 带重试的刷新
|
// refreshWithRetry 带重试的刷新
|
||||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error {
|
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher, executor OAuthRefreshExecutor, refreshWindow time.Duration) error {
|
||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||||
newCredentials, err := refresher.Refresh(ctx, account)
|
var newCredentials map[string]any
|
||||||
|
var err error
|
||||||
|
|
||||||
// 如果有新凭证,先更新(即使有错误也要保存 token)
|
// 优先使用统一 API(带分布式锁 + DB 重读保护)
|
||||||
if newCredentials != nil {
|
if s.refreshAPI != nil && executor != nil {
|
||||||
// 记录刷新版本时间戳,用于解决缓存一致性问题
|
result, refreshErr := s.refreshAPI.RefreshIfNeeded(ctx, account, executor, refreshWindow)
|
||||||
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
|
if refreshErr != nil {
|
||||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
err = refreshErr
|
||||||
|
} else if result.LockHeld {
|
||||||
account.Credentials = newCredentials
|
// 锁被其他 worker 持有,由调用侧策略决定如何计数
|
||||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
return s.refreshPolicy.handleLockHeld()
|
||||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
} else if !result.Refreshed {
|
||||||
|
// 已被其他路径刷新,由调用侧策略决定如何计数
|
||||||
|
return s.refreshPolicy.handleAlreadyRefreshed()
|
||||||
|
} else {
|
||||||
|
account = result.Account
|
||||||
|
_ = result.NewCredentials // 统一 API 已设置 _token_version 并更新 DB,无需重复操作
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 降级:直接调用 refresher(兼容旧路径)
|
||||||
|
newCredentials, err = refresher.Refresh(ctx, account)
|
||||||
|
if newCredentials != nil {
|
||||||
|
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||||
|
account.Credentials = newCredentials
|
||||||
|
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||||
|
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
s.postRefreshActions(ctx, account)
|
||||||
if account.Platform == PlatformAntigravity &&
|
|
||||||
account.Status == StatusError &&
|
|
||||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
|
||||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
|
||||||
slog.Warn("token_refresh.clear_account_error_failed",
|
|
||||||
"account_id", account.ID,
|
|
||||||
"error", clearErr,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
|
||||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
|
||||||
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
|
|
||||||
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
|
|
||||||
"account_id", account.ID,
|
|
||||||
"error", clearErr,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
|
||||||
}
|
|
||||||
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
|
||||||
if s.tempUnschedCache != nil {
|
|
||||||
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
|
|
||||||
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
|
|
||||||
"account_id", account.ID,
|
|
||||||
"error", clearErr,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
|
||||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
|
||||||
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
|
||||||
slog.Warn("token_refresh.invalidate_token_cache_failed",
|
|
||||||
"account_id", account.ID,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
|
|
||||||
// 这解决了 token 刷新后调度器缓存数据不一致的问题(#445)
|
|
||||||
if s.schedulerCache != nil {
|
|
||||||
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
|
|
||||||
slog.Warn("token_refresh.sync_scheduler_cache_failed",
|
|
||||||
"account_id", account.ID,
|
|
||||||
"error", err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享
|
|
||||||
s.ensureOpenAIPrivacy(ctx, account)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -331,6 +328,70 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
|||||||
return lastErr
|
return lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// postRefreshActions 刷新成功后的后续动作(清除错误状态、缓存失效、调度器同步等)
|
||||||
|
func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *Account) {
|
||||||
|
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||||
|
if account.Platform == PlatformAntigravity &&
|
||||||
|
account.Status == StatusError &&
|
||||||
|
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||||
|
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||||
|
slog.Warn("token_refresh.clear_account_error_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", clearErr,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
||||||
|
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||||
|
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
|
||||||
|
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", clearErr,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
||||||
|
}
|
||||||
|
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
||||||
|
if s.tempUnschedCache != nil {
|
||||||
|
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
|
||||||
|
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", clearErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||||
|
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||||
|
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||||
|
slog.Warn("token_refresh.invalidate_token_cache_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
slog.Debug("token_refresh.token_cache_invalidated", "account_id", account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 同步更新调度器缓存,确保调度获取的 Account 对象包含最新的 credentials
|
||||||
|
if s.schedulerCache != nil {
|
||||||
|
if err := s.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||||
|
slog.Warn("token_refresh.sync_scheduler_cache_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享
|
||||||
|
s.ensureOpenAIPrivacy(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// errRefreshSkipped 表示刷新被跳过(锁竞争或已被其他路径刷新),不计入 failed 或 refreshed
|
||||||
|
var errRefreshSkipped = fmt.Errorf("refresh skipped")
|
||||||
|
|
||||||
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
|
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
|
||||||
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
|
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
|
||||||
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
|
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
|
||||||
|
|||||||
@@ -84,6 +84,10 @@ func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map
|
|||||||
return r.credentials, nil
|
return r.credentials, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *tokenRefresherStub) CacheKey(account *Account) string {
|
||||||
|
return "test:stub:" + account.Platform
|
||||||
|
}
|
||||||
|
|
||||||
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||||
repo := &tokenRefreshAccountRepo{}
|
repo := &tokenRefreshAccountRepo{}
|
||||||
invalidator := &tokenCacheInvalidatorStub{}
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
@@ -105,7 +109,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
require.Equal(t, 1, invalidator.calls)
|
require.Equal(t, 1, invalidator.calls)
|
||||||
@@ -133,7 +137,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
require.Equal(t, 1, invalidator.calls)
|
require.Equal(t, 1, invalidator.calls)
|
||||||
@@ -159,7 +163,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
}
|
}
|
||||||
@@ -186,7 +190,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
|
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
|
||||||
@@ -214,7 +218,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
|
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
|
||||||
@@ -242,7 +246,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
||||||
@@ -270,7 +274,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "failed to save credentials")
|
require.Contains(t, err.Error(), "failed to save credentials")
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
@@ -297,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
|||||||
err: errors.New("refresh failed"),
|
err: errors.New("refresh failed"),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
||||||
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
|
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
|
||||||
@@ -324,7 +328,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
|
|||||||
err: errors.New("network error"), // 可重试错误
|
err: errors.New("network error"), // 可重试错误
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, 0, repo.updateCalls)
|
require.Equal(t, 0, repo.updateCalls)
|
||||||
require.Equal(t, 0, invalidator.calls)
|
require.Equal(t, 0, invalidator.calls)
|
||||||
@@ -351,7 +355,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
|
|||||||
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
|
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, 0, repo.updateCalls)
|
require.Equal(t, 0, repo.updateCalls)
|
||||||
require.Equal(t, 0, invalidator.calls)
|
require.Equal(t, 0, invalidator.calls)
|
||||||
@@ -383,7 +387,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
|
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
|
||||||
@@ -422,7 +426,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t
|
|||||||
err: errors.New("invalid_grant: token revoked"),
|
err: errors.New("invalid_grant: token revoked"),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError
|
require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError
|
||||||
})
|
})
|
||||||
@@ -453,3 +457,212 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ========== Path A (refreshAPI) 测试用例 ==========
|
||||||
|
|
||||||
|
// mockTokenCacheForRefreshAPI 用于 Path A 测试的 GeminiTokenCache mock
|
||||||
|
type mockTokenCacheForRefreshAPI struct {
|
||||||
|
lockResult bool
|
||||||
|
lockErr error
|
||||||
|
releaseCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenCacheForRefreshAPI) GetAccessToken(_ context.Context, _ string) (string, error) {
|
||||||
|
return "", errors.New("not cached")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenCacheForRefreshAPI) SetAccessToken(_ context.Context, _ string, _ string, _ time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenCacheForRefreshAPI) DeleteAccessToken(_ context.Context, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenCacheForRefreshAPI) AcquireRefreshLock(_ context.Context, _ string, _ time.Duration) (bool, error) {
|
||||||
|
return m.lockResult, m.lockErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTokenCacheForRefreshAPI) ReleaseRefreshLock(_ context.Context, _ string) error {
|
||||||
|
m.releaseCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildPathAService 构建注入了 refreshAPI 的 service(Path A 测试辅助)
|
||||||
|
func buildPathAService(repo *tokenRefreshAccountRepo, cache GeminiTokenCache, invalidator TokenCacheInvalidator) (*TokenRefreshService, *tokenRefresherStub) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||||
|
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
service.SetRefreshAPI(refreshAPI)
|
||||||
|
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "refreshed-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return service, refresher
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPathA_Success 统一 API 路径正常成功:刷新 + DB 更新 + postRefreshActions
|
||||||
|
func TestPathA_Success(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 100,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||||
|
|
||||||
|
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.updateCalls) // DB 更新被调用
|
||||||
|
require.Equal(t, 1, invalidator.calls) // 缓存失效被调用
|
||||||
|
require.Equal(t, 1, cache.releaseCalls) // 锁被释放
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPathA_LockHeld 锁被其他 worker 持有 → 返回 errRefreshSkipped
|
||||||
|
func TestPathA_LockHeld(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 101,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cache := &mockTokenCacheForRefreshAPI{lockResult: false} // 锁获取失败(被占)
|
||||||
|
|
||||||
|
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
|
require.ErrorIs(t, err, errRefreshSkipped)
|
||||||
|
require.Equal(t, 0, repo.updateCalls) // 不应更新 DB
|
||||||
|
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPathA_AlreadyRefreshed 二次检查发现已被其他路径刷新 → 返回 errRefreshSkipped
|
||||||
|
func TestPathA_AlreadyRefreshed(t *testing.T) {
|
||||||
|
// NeedsRefresh 返回 false → RefreshIfNeeded 返回 {Refreshed: false}
|
||||||
|
account := &Account{
|
||||||
|
ID: 102,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||||
|
|
||||||
|
service, _ := buildPathAService(repo, cache, invalidator)
|
||||||
|
|
||||||
|
// 使用一个 NeedsRefresh 返回 false 的 stub
|
||||||
|
noRefreshNeeded := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{"access_token": "token"},
|
||||||
|
}
|
||||||
|
// 覆盖 NeedsRefresh 行为 — 我们需要一个新的 stub 类型
|
||||||
|
alwaysFreshStub := &alwaysFreshRefresherStub{}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, noRefreshNeeded, alwaysFreshStub, time.Hour)
|
||||||
|
require.ErrorIs(t, err, errRefreshSkipped)
|
||||||
|
require.Equal(t, 0, repo.updateCalls)
|
||||||
|
require.Equal(t, 0, invalidator.calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// alwaysFreshRefresherStub 二次检查时认为不需要刷新(模拟已被其他路径刷新)
|
||||||
|
type alwaysFreshRefresherStub struct{}
|
||||||
|
|
||||||
|
func (r *alwaysFreshRefresherStub) CanRefresh(_ *Account) bool { return true }
|
||||||
|
func (r *alwaysFreshRefresherStub) NeedsRefresh(_ *Account, _ time.Duration) bool { return false }
|
||||||
|
func (r *alwaysFreshRefresherStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
|
||||||
|
return nil, errors.New("should not be called")
|
||||||
|
}
|
||||||
|
func (r *alwaysFreshRefresherStub) CacheKey(account *Account) string {
|
||||||
|
return "test:fresh:" + account.Platform
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPathA_NonRetryableError 统一 API 路径返回不可重试错误 → SetError
|
||||||
|
func TestPathA_NonRetryableError(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 103,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||||
|
|
||||||
|
service, _ := buildPathAService(repo, cache, invalidator)
|
||||||
|
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
err: errors.New("invalid_grant: token revoked"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls) // 应标记 error 状态
|
||||||
|
require.Equal(t, 0, repo.updateCalls) // 不应更新 credentials
|
||||||
|
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPathA_RetryableErrorExhausted 统一 API 路径可重试错误耗尽 → 不标记 error
|
||||||
|
func TestPathA_RetryableErrorExhausted(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 104,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 2,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
|
||||||
|
refreshAPI := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
service.SetRefreshAPI(refreshAPI)
|
||||||
|
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
err: errors.New("network timeout"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 0, repo.setErrorCalls) // 可重试错误不标记 error
|
||||||
|
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
||||||
|
require.Equal(t, 0, invalidator.calls) // 不应触发缓存失效
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPathA_DBUpdateFailed 统一 API 路径 DB 更新失败 → 返回 error,不执行 postRefreshActions
|
||||||
|
func TestPathA_DBUpdateFailed(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
ID: 105,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
repo := &tokenRefreshAccountRepo{updateErr: errors.New("db connection lost")}
|
||||||
|
repo.accountsByID = map[int64]*Account{account.ID: account}
|
||||||
|
invalidator := &tokenCacheInvalidatorStub{}
|
||||||
|
cache := &mockTokenCacheForRefreshAPI{lockResult: true}
|
||||||
|
|
||||||
|
service, refresher := buildPathAService(repo, cache, invalidator)
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "DB update failed")
|
||||||
|
require.Equal(t, 1, repo.updateCalls) // DB 更新被尝试
|
||||||
|
require.Equal(t, 0, invalidator.calls) // DB 失败时不应触发缓存失效
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,6 +32,11 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheKey 返回用于分布式锁的缓存键
|
||||||
|
func (r *ClaudeTokenRefresher) CacheKey(account *Account) string {
|
||||||
|
return ClaudeTokenCacheKey(account)
|
||||||
|
}
|
||||||
|
|
||||||
// CanRefresh 检查是否能处理此账号
|
// CanRefresh 检查是否能处理此账号
|
||||||
// 只处理 anthropic 平台的 oauth 类型账号
|
// 只处理 anthropic 平台的 oauth 类型账号
|
||||||
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
|
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
|
||||||
@@ -59,24 +63,8 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保留现有credentials中的所有字段
|
newCredentials := BuildClaudeAccountCredentials(tokenInfo)
|
||||||
newCredentials := make(map[string]any)
|
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||||
for k, v := range account.Credentials {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// 只更新token相关字段
|
|
||||||
// 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
|
|
||||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
|
||||||
newCredentials["token_type"] = tokenInfo.TokenType
|
|
||||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
|
||||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
|
||||||
if tokenInfo.RefreshToken != "" {
|
|
||||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
|
||||||
}
|
|
||||||
if tokenInfo.Scope != "" {
|
|
||||||
newCredentials["scope"] = tokenInfo.Scope
|
|
||||||
}
|
|
||||||
|
|
||||||
return newCredentials, nil
|
return newCredentials, nil
|
||||||
}
|
}
|
||||||
@@ -97,6 +85,11 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheKey 返回用于分布式锁的缓存键
|
||||||
|
func (r *OpenAITokenRefresher) CacheKey(account *Account) string {
|
||||||
|
return OpenAITokenCacheKey(account)
|
||||||
|
}
|
||||||
|
|
||||||
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
|
// SetSoraAccountRepo 设置 Sora 账号扩展表仓储
|
||||||
// 用于在 Token 刷新时同步更新 sora_accounts 表
|
// 用于在 Token 刷新时同步更新 sora_accounts 表
|
||||||
// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials
|
// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials
|
||||||
@@ -137,13 +130,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
|||||||
|
|
||||||
// 使用服务提供的方法构建新凭证,并保留原有字段
|
// 使用服务提供的方法构建新凭证,并保留原有字段
|
||||||
newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||||
// 保留原有credentials中非token相关字段
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
// 异步同步关联的 Sora 账号(不阻塞主流程)
|
||||||
if r.accountRepo != nil && r.syncLinkedSora {
|
if r.accountRepo != nil && r.syncLinkedSora {
|
||||||
|
|||||||
@@ -51,16 +51,77 @@ func ProvideTokenRefreshService(
|
|||||||
tempUnschedCache TempUnschedCache,
|
tempUnschedCache TempUnschedCache,
|
||||||
privacyClientFactory PrivacyClientFactory,
|
privacyClientFactory PrivacyClientFactory,
|
||||||
proxyRepo ProxyRepository,
|
proxyRepo ProxyRepository,
|
||||||
|
refreshAPI *OAuthRefreshAPI,
|
||||||
) *TokenRefreshService {
|
) *TokenRefreshService {
|
||||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||||
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
||||||
svc.SetSoraAccountRepo(soraAccountRepo)
|
svc.SetSoraAccountRepo(soraAccountRepo)
|
||||||
// 注入 OpenAI privacy opt-out 依赖
|
// 注入 OpenAI privacy opt-out 依赖
|
||||||
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
||||||
|
// 注入统一 OAuth 刷新 API(消除 TokenRefreshService 与 TokenProvider 之间的竞争条件)
|
||||||
|
svc.SetRefreshAPI(refreshAPI)
|
||||||
|
// 调用侧显式注入后台刷新策略,避免策略漂移
|
||||||
|
svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy())
|
||||||
svc.Start()
|
svc.Start()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideClaudeTokenProvider creates ClaudeTokenProvider with OAuthRefreshAPI injection
|
||||||
|
func ProvideClaudeTokenProvider(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
tokenCache GeminiTokenCache,
|
||||||
|
oauthService *OAuthService,
|
||||||
|
refreshAPI *OAuthRefreshAPI,
|
||||||
|
) *ClaudeTokenProvider {
|
||||||
|
p := NewClaudeTokenProvider(accountRepo, tokenCache, oauthService)
|
||||||
|
executor := NewClaudeTokenRefresher(oauthService)
|
||||||
|
p.SetRefreshAPI(refreshAPI, executor)
|
||||||
|
p.SetRefreshPolicy(ClaudeProviderRefreshPolicy())
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideOpenAITokenProvider creates OpenAITokenProvider with OAuthRefreshAPI injection
|
||||||
|
func ProvideOpenAITokenProvider(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
tokenCache GeminiTokenCache,
|
||||||
|
openaiOAuthService *OpenAIOAuthService,
|
||||||
|
refreshAPI *OAuthRefreshAPI,
|
||||||
|
) *OpenAITokenProvider {
|
||||||
|
p := NewOpenAITokenProvider(accountRepo, tokenCache, openaiOAuthService)
|
||||||
|
executor := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
|
||||||
|
p.SetRefreshAPI(refreshAPI, executor)
|
||||||
|
p.SetRefreshPolicy(OpenAIProviderRefreshPolicy())
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideGeminiTokenProvider creates GeminiTokenProvider with OAuthRefreshAPI injection
|
||||||
|
func ProvideGeminiTokenProvider(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
tokenCache GeminiTokenCache,
|
||||||
|
geminiOAuthService *GeminiOAuthService,
|
||||||
|
refreshAPI *OAuthRefreshAPI,
|
||||||
|
) *GeminiTokenProvider {
|
||||||
|
p := NewGeminiTokenProvider(accountRepo, tokenCache, geminiOAuthService)
|
||||||
|
executor := NewGeminiTokenRefresher(geminiOAuthService)
|
||||||
|
p.SetRefreshAPI(refreshAPI, executor)
|
||||||
|
p.SetRefreshPolicy(GeminiProviderRefreshPolicy())
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideAntigravityTokenProvider creates AntigravityTokenProvider with OAuthRefreshAPI injection
|
||||||
|
func ProvideAntigravityTokenProvider(
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
tokenCache GeminiTokenCache,
|
||||||
|
antigravityOAuthService *AntigravityOAuthService,
|
||||||
|
refreshAPI *OAuthRefreshAPI,
|
||||||
|
) *AntigravityTokenProvider {
|
||||||
|
p := NewAntigravityTokenProvider(accountRepo, tokenCache, antigravityOAuthService)
|
||||||
|
executor := NewAntigravityTokenRefresher(antigravityOAuthService)
|
||||||
|
p.SetRefreshAPI(refreshAPI, executor)
|
||||||
|
p.SetRefreshPolicy(AntigravityProviderRefreshPolicy())
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
||||||
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||||
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
||||||
@@ -375,11 +436,12 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewCompositeTokenCacheInvalidator,
|
NewCompositeTokenCacheInvalidator,
|
||||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||||
NewAntigravityOAuthService,
|
NewAntigravityOAuthService,
|
||||||
NewGeminiTokenProvider,
|
NewOAuthRefreshAPI,
|
||||||
|
ProvideGeminiTokenProvider,
|
||||||
NewGeminiMessagesCompatService,
|
NewGeminiMessagesCompatService,
|
||||||
NewAntigravityTokenProvider,
|
ProvideAntigravityTokenProvider,
|
||||||
NewOpenAITokenProvider,
|
ProvideOpenAITokenProvider,
|
||||||
NewClaudeTokenProvider,
|
ProvideClaudeTokenProvider,
|
||||||
NewAntigravityGatewayService,
|
NewAntigravityGatewayService,
|
||||||
ProvideRateLimitService,
|
ProvideRateLimitService,
|
||||||
NewAccountUsageService,
|
NewAccountUsageService,
|
||||||
|
|||||||
Reference in New Issue
Block a user