Files
sub2api/backend/internal/service/gemini_token_provider.go
erio 1fc9dd7b68 feat: unified OAuth token refresh API with distributed locking
Introduce OAuthRefreshAPI as the single entry point for all OAuth token
refresh operations, eliminating the race condition where background
refresh and inline refresh could simultaneously use the same
refresh_token (fixes #1035).

Key changes:
- Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
2026-03-16 01:31:54 +08:00

178 lines
5.5 KiB
Go

package service
import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
)
const (
geminiTokenRefreshSkew = 3 * time.Minute
geminiTokenCacheSkew = 5 * time.Minute
)
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
type GeminiTokenProvider struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache
geminiOAuthService *GeminiOAuthService
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewGeminiTokenProvider(
accountRepo AccountRepository,
tokenCache GeminiTokenCache,
geminiOAuthService *GeminiOAuthService,
) *GeminiTokenProvider {
return &GeminiTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
geminiOAuthService: geminiOAuthService,
refreshPolicy: GeminiProviderRefreshPolicy(),
}
}
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *GeminiTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *GeminiTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return "", errors.New("not a gemini oauth account")
}
cacheKey := GeminiTokenCacheKey(account)
// 1) Try cache first.
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, geminiTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
slog.Debug("gemini_token_lock_held_use_old", "account_id", account.ID)
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("gemini_token_lock_failed", "account_id", account.ID, "error", lockErr)
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// project_id is optional now:
// - If present: use Code Assist API (requires project_id)
// - If absent: use AI Studio API with OAuth token.
projectID := strings.TrimSpace(account.GetCredential("project_id"))
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
if projectID == "" && autoDetectProjectID {
if p.geminiOAuthService == nil {
return accessToken, nil
}
var proxyURL string
if account.ProxyID != nil && p.geminiOAuthService.proxyRepo != nil {
if proxy, err := p.geminiOAuthService.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
if err != nil {
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
return accessToken, nil
}
detected = strings.TrimSpace(detected)
tierID = strings.TrimSpace(tierID)
if detected != "" {
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials["project_id"] = detected
if tierID != "" {
account.Credentials["tier_id"] = tierID
}
_ = p.accountRepo.Update(ctx, account)
}
}
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > geminiTokenCacheSkew:
ttl = until - geminiTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
}
return accessToken, nil
}
func GeminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "gemini:" + projectID
}
return "gemini:account:" + strconv.FormatInt(account.ID, 10)
}