mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
Introduce OAuthRefreshAPI as the single entry point for all OAuth token refresh operations, eliminating the race condition in which background refresh and inline refresh could use the same refresh_token simultaneously (fixes #1035).

Key changes:
- Add an OAuthRefreshExecutor interface that extends TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with a lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use the unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
324 lines
9.6 KiB
Go
324 lines
9.6 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"math/rand/v2"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Token lifetime windows used by GetAccessToken.
//
// openAITokenRefreshSkew: a refresh is attempted once expires_at is missing or
// no further away than this.
// openAITokenCacheSkew: subtracted from the remaining token lifetime when the
// cache TTL is computed, so cached entries expire before the token does.
const (
	openAITokenRefreshSkew = 3 * time.Minute
	openAITokenCacheSkew   = 5 * time.Minute
)

// Polling parameters for waitForTokenAfterLockRace. The per-attempt delay starts
// at openAILockInitialWait, doubles up to openAILockMaxWait, and is jittered by
// ±openAILockJitterRatio. At most openAILockMaxAttempts polls are made, and a
// warning is logged once the accumulated wait reaches openAILockWarnThresholdMs.
const (
	openAILockInitialWait     = 20 * time.Millisecond
	openAILockMaxWait         = 120 * time.Millisecond
	openAILockMaxAttempts     = 5
	openAILockJitterRatio     = 0.2
	openAILockWarnThresholdMs = 250
)
|
|
|
|
// OpenAITokenRuntimeMetrics is a snapshot of refresh and lock contention metrics.
// Values are copied out of openAITokenRuntimeMetricsStore by snapshot(); in this
// file the underlying counters are only ever incremented, never reset.
type OpenAITokenRuntimeMetrics struct {
	// Inline refresh activity in GetAccessToken: RefreshRequests counts calls that
	// found the token due for refresh and entered a refresh path; RefreshSuccess
	// counts refresh-API calls that reported Refreshed; RefreshFailure counts
	// failed refresh-API calls that GetAccessToken recorded.
	RefreshRequests int64
	RefreshSuccess  int64
	RefreshFailure  int64

	// Refresh-lock coordination: LockAcquireFailure counts errors from
	// AcquireRefreshLock on the legacy path; LockContention counts times this
	// caller decided to wait for another lock holder's token.
	LockAcquireFailure int64
	LockContention     int64

	// Waiting on another lock holder: LockWaitSamples is the number of individual
	// sleeps, LockWaitTotalMs their summed duration in milliseconds, and
	// LockWaitHit / LockWaitMiss count waits that did / did not find a token.
	LockWaitSamples int64
	LockWaitTotalMs int64
	LockWaitHit     int64
	LockWaitMiss    int64

	// LastObservedUnixMs is the Unix-millisecond time of the last touchNow call.
	LastObservedUnixMs int64
}
|
|
|
|
// openAITokenRuntimeMetricsStore holds the live counters that back
// OpenAITokenRuntimeMetrics. Each field is an atomic.Int64, so individual
// counters can be incremented and loaded without extra locking; snapshot()
// copies them into the exported struct one field at a time.
type openAITokenRuntimeMetricsStore struct {
	refreshRequests    atomic.Int64
	refreshSuccess     atomic.Int64
	refreshFailure     atomic.Int64
	lockAcquireFailure atomic.Int64
	lockContention     atomic.Int64
	lockWaitSamples    atomic.Int64
	lockWaitTotalMs    atomic.Int64
	lockWaitHit        atomic.Int64
	lockWaitMiss       atomic.Int64
	lastObservedUnixMs atomic.Int64
}
|
|
|
|
func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics {
|
|
if m == nil {
|
|
return OpenAITokenRuntimeMetrics{}
|
|
}
|
|
return OpenAITokenRuntimeMetrics{
|
|
RefreshRequests: m.refreshRequests.Load(),
|
|
RefreshSuccess: m.refreshSuccess.Load(),
|
|
RefreshFailure: m.refreshFailure.Load(),
|
|
LockAcquireFailure: m.lockAcquireFailure.Load(),
|
|
LockContention: m.lockContention.Load(),
|
|
LockWaitSamples: m.lockWaitSamples.Load(),
|
|
LockWaitTotalMs: m.lockWaitTotalMs.Load(),
|
|
LockWaitHit: m.lockWaitHit.Load(),
|
|
LockWaitMiss: m.lockWaitMiss.Load(),
|
|
LastObservedUnixMs: m.lastObservedUnixMs.Load(),
|
|
}
|
|
}
|
|
|
|
func (m *openAITokenRuntimeMetricsStore) touchNow() {
|
|
if m == nil {
|
|
return
|
|
}
|
|
m.lastObservedUnixMs.Store(time.Now().UnixMilli())
|
|
}
|
|
|
|
// OpenAITokenCache is the token cache used by OpenAITokenProvider. It is a type
// alias for GeminiTokenCache, not a distinct type, so any GeminiTokenCache value
// can be used here directly. OpenAITokenProvider calls its GetAccessToken,
// SetAccessToken, AcquireRefreshLock and ReleaseRefreshLock methods.
type OpenAITokenCache = GeminiTokenCache
|
|
|
|
// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
// It serves tokens from an optional cache, refreshes them through the unified
// OAuth refresh API when one is injected, and records refresh/lock metrics.
type OpenAITokenProvider struct {
	// accountRepo is handed to CheckTokenVersion before a token is written to the cache.
	accountRepo AccountRepository

	// tokenCache holds access tokens and refresh locks keyed by OpenAITokenCacheKey;
	// GetAccessToken skips cache reads and writes when it is nil.
	tokenCache OpenAITokenCache

	// NOTE(review): openAIOAuthService is set by NewOpenAITokenProvider but is not
	// referenced anywhere in this file; presumably consumed elsewhere — confirm.
	openAIOAuthService *OpenAIOAuthService

	// metrics holds the runtime counters; ensureMetrics allocates it if nil.
	metrics *openAITokenRuntimeMetricsStore

	// refreshAPI and executor are injected together via SetRefreshAPI. Unless both
	// are non-nil, GetAccessToken falls back to its legacy lock-only path.
	refreshAPI *OAuthRefreshAPI
	executor OAuthRefreshExecutor

	// refreshPolicy controls how GetAccessToken reacts to refresh errors and lock
	// contention, and the cache TTL it uses after a failed refresh.
	refreshPolicy ProviderRefreshPolicy
}
|
|
|
|
func NewOpenAITokenProvider(
|
|
accountRepo AccountRepository,
|
|
tokenCache OpenAITokenCache,
|
|
openAIOAuthService *OpenAIOAuthService,
|
|
) *OpenAITokenProvider {
|
|
return &OpenAITokenProvider{
|
|
accountRepo: accountRepo,
|
|
tokenCache: tokenCache,
|
|
openAIOAuthService: openAIOAuthService,
|
|
metrics: &openAITokenRuntimeMetricsStore{},
|
|
refreshPolicy: OpenAIProviderRefreshPolicy(),
|
|
}
|
|
}
|
|
|
|
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
|
func (p *OpenAITokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
|
p.refreshAPI = api
|
|
p.executor = executor
|
|
}
|
|
|
|
// SetRefreshPolicy injects caller-side refresh policy.
|
|
func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
|
p.refreshPolicy = policy
|
|
}
|
|
|
|
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
|
if p == nil {
|
|
return OpenAITokenRuntimeMetrics{}
|
|
}
|
|
p.ensureMetrics()
|
|
return p.metrics.snapshot()
|
|
}
|
|
|
|
func (p *OpenAITokenProvider) ensureMetrics() {
|
|
if p != nil && p.metrics == nil {
|
|
p.metrics = &openAITokenRuntimeMetricsStore{}
|
|
}
|
|
}
|
|
|
|
// GetAccessToken returns a valid access_token.
|
|
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
|
p.ensureMetrics()
|
|
if account == nil {
|
|
return "", errors.New("account is nil")
|
|
}
|
|
if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth {
|
|
return "", errors.New("not an openai/sora oauth account")
|
|
}
|
|
|
|
cacheKey := OpenAITokenCacheKey(account)
|
|
|
|
// 1) Try cache first.
|
|
if p.tokenCache != nil {
|
|
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
|
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
|
|
return token, nil
|
|
} else if err != nil {
|
|
slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
|
|
}
|
|
}
|
|
|
|
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
|
|
|
|
// 2) Refresh if needed (pre-expiry skew).
|
|
expiresAt := account.GetCredentialAsTime("expires_at")
|
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
|
refreshFailed := false
|
|
|
|
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
|
p.metrics.refreshRequests.Add(1)
|
|
p.metrics.touchNow()
|
|
|
|
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
|
|
if account.Platform == PlatformSora {
|
|
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
|
|
refreshFailed = true
|
|
} else {
|
|
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
|
|
if err != nil {
|
|
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
|
return "", err
|
|
}
|
|
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
|
p.metrics.refreshFailure.Add(1)
|
|
refreshFailed = true
|
|
} else if result.LockHeld {
|
|
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
|
|
p.metrics.lockContention.Add(1)
|
|
p.metrics.touchNow()
|
|
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
|
if waitErr != nil {
|
|
return "", waitErr
|
|
}
|
|
if strings.TrimSpace(token) != "" {
|
|
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
|
return token, nil
|
|
}
|
|
}
|
|
} else if result.Refreshed {
|
|
p.metrics.refreshSuccess.Add(1)
|
|
account = result.Account
|
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
} else {
|
|
account = result.Account
|
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
|
}
|
|
}
|
|
} else if needsRefresh && p.tokenCache != nil {
|
|
// Backward-compatible test path when refreshAPI is not injected.
|
|
p.metrics.refreshRequests.Add(1)
|
|
p.metrics.touchNow()
|
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
|
if lockErr == nil && locked {
|
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
|
} else if lockErr != nil {
|
|
p.metrics.lockAcquireFailure.Add(1)
|
|
p.metrics.touchNow()
|
|
slog.Warn("openai_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
|
} else {
|
|
p.metrics.lockContention.Add(1)
|
|
p.metrics.touchNow()
|
|
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
|
if waitErr != nil {
|
|
return "", waitErr
|
|
}
|
|
if strings.TrimSpace(token) != "" {
|
|
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
|
return token, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
accessToken := account.GetCredential("access_token")
|
|
if strings.TrimSpace(accessToken) == "" {
|
|
return "", errors.New("access_token not found in credentials")
|
|
}
|
|
|
|
// 3) Populate cache with TTL.
|
|
if p.tokenCache != nil {
|
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
|
if isStale && latestAccount != nil {
|
|
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
|
|
accessToken = latestAccount.GetOpenAIAccessToken()
|
|
if strings.TrimSpace(accessToken) == "" {
|
|
return "", errors.New("access_token not found after version check")
|
|
}
|
|
} else {
|
|
ttl := 30 * time.Minute
|
|
if refreshFailed {
|
|
if p.refreshPolicy.FailureTTL > 0 {
|
|
ttl = p.refreshPolicy.FailureTTL
|
|
} else {
|
|
ttl = time.Minute
|
|
}
|
|
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
|
} else if expiresAt != nil {
|
|
until := time.Until(*expiresAt)
|
|
switch {
|
|
case until > openAITokenCacheSkew:
|
|
ttl = until - openAITokenCacheSkew
|
|
case until > 0:
|
|
ttl = until
|
|
default:
|
|
ttl = time.Minute
|
|
}
|
|
}
|
|
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
|
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return accessToken, nil
|
|
}
|
|
|
|
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
|
|
wait := openAILockInitialWait
|
|
totalWaitMs := int64(0)
|
|
for i := 0; i < openAILockMaxAttempts; i++ {
|
|
actualWait := jitterLockWait(wait)
|
|
timer := time.NewTimer(actualWait)
|
|
select {
|
|
case <-ctx.Done():
|
|
if !timer.Stop() {
|
|
select {
|
|
case <-timer.C:
|
|
default:
|
|
}
|
|
}
|
|
return "", ctx.Err()
|
|
case <-timer.C:
|
|
}
|
|
|
|
waitMs := actualWait.Milliseconds()
|
|
if waitMs < 0 {
|
|
waitMs = 0
|
|
}
|
|
totalWaitMs += waitMs
|
|
p.metrics.lockWaitSamples.Add(1)
|
|
p.metrics.lockWaitTotalMs.Add(waitMs)
|
|
p.metrics.touchNow()
|
|
|
|
token, err := p.tokenCache.GetAccessToken(ctx, cacheKey)
|
|
if err == nil && strings.TrimSpace(token) != "" {
|
|
p.metrics.lockWaitHit.Add(1)
|
|
if totalWaitMs >= openAILockWarnThresholdMs {
|
|
slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", i+1)
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
if wait < openAILockMaxWait {
|
|
wait *= 2
|
|
if wait > openAILockMaxWait {
|
|
wait = openAILockMaxWait
|
|
}
|
|
}
|
|
}
|
|
|
|
p.metrics.lockWaitMiss.Add(1)
|
|
if totalWaitMs >= openAILockWarnThresholdMs {
|
|
slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", openAILockMaxAttempts)
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func jitterLockWait(base time.Duration) time.Duration {
|
|
if base <= 0 {
|
|
return 0
|
|
}
|
|
minFactor := 1 - openAILockJitterRatio
|
|
maxFactor := 1 + openAILockJitterRatio
|
|
factor := minFactor + rand.Float64()*(maxFactor-minFactor)
|
|
return time.Duration(float64(base) * factor)
|
|
}
|