mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
fix(account): preserve runtime state during credentials-only updates
This commit is contained in:
@@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
||||||
|
_, err := r.client.Account.UpdateOneID(id).
|
||||||
|
SetCredentials(normalizeJSONMap(credentials)).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||||
|
}
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||||
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
|
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
30
backend/internal/service/account_credentials_persistence.go
Normal file
30
backend/internal/service/account_credentials_persistence.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type accountCredentialsUpdater interface {
|
||||||
|
UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error {
|
||||||
|
if repo == nil || account == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Credentials = cloneCredentials(credentials)
|
||||||
|
if updater, ok := any(repo).(accountCredentialsUpdater); ok {
|
||||||
|
return updater.UpdateCredentials(ctx, account.ID, account.Credentials)
|
||||||
|
}
|
||||||
|
return repo.Update(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneCredentials(in map[string]any) map[string]any {
|
||||||
|
if in == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
out := make(map[string]any, len(in))
|
||||||
|
for k, v := range in {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
p.markBackfillAttempted(account.ID)
|
p.markBackfillAttempted(account.ID)
|
||||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||||
account.Credentials["project_id"] = projectID
|
account.Credentials["project_id"] = projectID
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
if updateErr := persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil {
|
||||||
slog.Warn("antigravity_project_id_backfill_persist_failed",
|
slog.Warn("antigravity_project_id_backfill_persist_failed",
|
||||||
"account_id", account.ID,
|
"account_id", account.ID,
|
||||||
"error", updateErr,
|
"error", updateErr,
|
||||||
|
|||||||
@@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
// 🔄 Refresh OAuth token after creation
|
// 🔄 Refresh OAuth token after creation
|
||||||
if targetType == AccountTypeOAuth {
|
if targetType == AccountTypeOAuth {
|
||||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||||
account.Credentials = refreshedCreds
|
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
|
||||||
_ = s.accountRepo.Update(ctx, account)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
item.Action = "created"
|
item.Action = "created"
|
||||||
@@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
// 🔄 Refresh OAuth token after update
|
// 🔄 Refresh OAuth token after update
|
||||||
if targetType == AccountTypeOAuth {
|
if targetType == AccountTypeOAuth {
|
||||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||||
existing.Credentials = refreshedCreds
|
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
|
||||||
_ = s.accountRepo.Update(ctx, existing)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
// 🔄 Refresh OAuth token after creation
|
// 🔄 Refresh OAuth token after creation
|
||||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||||
account.Credentials = refreshedCreds
|
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
|
||||||
_ = s.accountRepo.Update(ctx, account)
|
|
||||||
}
|
}
|
||||||
item.Action = "created"
|
item.Action = "created"
|
||||||
result.Created++
|
result.Created++
|
||||||
@@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
|
|
||||||
// 🔄 Refresh OAuth token after update
|
// 🔄 Refresh OAuth token after update
|
||||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||||
existing.Credentials = refreshedCreds
|
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
|
||||||
_ = s.accountRepo.Update(ctx, existing)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
item.Action = "updated"
|
item.Action = "updated"
|
||||||
@@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||||
account.Credentials = refreshedCreds
|
_ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds)
|
||||||
_ = s.accountRepo.Update(ctx, account)
|
|
||||||
}
|
}
|
||||||
item.Action = "created"
|
item.Action = "created"
|
||||||
result.Created++
|
result.Created++
|
||||||
@@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
|
|
||||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||||
existing.Credentials = refreshedCreds
|
_ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds)
|
||||||
_ = s.accountRepo.Update(ctx, existing)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
item.Action = "updated"
|
item.Action = "updated"
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
if tierID != "" {
|
if tierID != "" {
|
||||||
account.Credentials["tier_id"] = tierID
|
account.Credentials["tier_id"] = tierID
|
||||||
}
|
}
|
||||||
_ = p.accountRepo.Update(ctx, account)
|
_ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
|||||||
// 5. 设置版本号 + 更新 DB
|
// 5. 设置版本号 + 更新 DB
|
||||||
if newCredentials != nil {
|
if newCredentials != nil {
|
||||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||||
freshAccount.Credentials = newCredentials
|
if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil {
|
||||||
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
|
|
||||||
slog.Error("oauth_refresh_update_failed",
|
slog.Error("oauth_refresh_update_failed",
|
||||||
"account_id", freshAccount.ID,
|
"account_id", freshAccount.ID,
|
||||||
"error", updateErr,
|
"error", updateErr,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type refreshAPIAccountRepo struct {
|
|||||||
getByIDErr error
|
getByIDErr error
|
||||||
updateErr error
|
updateErr error
|
||||||
updateCalls int
|
updateCalls int
|
||||||
|
updateCredentialsCalls int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
|
||||||
@@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
|
|||||||
return r.updateErr
|
return r.updateErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error {
|
||||||
|
r.updateCalls++
|
||||||
|
r.updateCredentialsCalls++
|
||||||
|
if r.updateErr != nil {
|
||||||
|
return r.updateErr
|
||||||
|
}
|
||||||
|
if r.account == nil || r.account.ID != id {
|
||||||
|
r.account = &Account{ID: id}
|
||||||
|
}
|
||||||
|
r.account.Credentials = cloneCredentials(credentials)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
|
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
|
||||||
type refreshAPIExecutorStub struct {
|
type refreshAPIExecutorStub struct {
|
||||||
needsRefresh bool
|
needsRefresh bool
|
||||||
@@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) {
|
|||||||
require.Equal(t, "new-token", result.NewCredentials["access_token"])
|
require.Equal(t, "new-token", result.NewCredentials["access_token"])
|
||||||
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
|
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
|
||||||
require.Equal(t, 1, repo.updateCalls) // DB updated
|
require.Equal(t, 1, repo.updateCalls) // DB updated
|
||||||
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||||
require.Equal(t, 1, cache.releaseCalls) // lock released
|
require.Equal(t, 1, cache.releaseCalls) // lock released
|
||||||
require.Equal(t, 1, executor.refreshCalls)
|
require.Equal(t, 1, executor.refreshCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) {
|
||||||
|
resetAt := time.Now().Add(45 * time.Minute)
|
||||||
|
account := &Account{
|
||||||
|
ID: 11,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
RateLimitResetAt: &resetAt,
|
||||||
|
}
|
||||||
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
executor := &refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "safe-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.Refreshed)
|
||||||
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||||
|
require.NotNil(t, repo.account.RateLimitResetAt)
|
||||||
|
require.WithinDuration(t, resetAt, *repo.account.RateLimitResetAt, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
func TestRefreshIfNeeded_LockHeld(t *testing.T) {
|
func TestRefreshIfNeeded_LockHeld(t *testing.T) {
|
||||||
account := &Account{ID: 2, Platform: PlatformAnthropic}
|
account := &Account{ID: 2, Platform: PlatformAnthropic}
|
||||||
repo := &refreshAPIAccountRepo{account: account}
|
repo := &refreshAPIAccountRepo{account: account}
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
account.Credentials = make(map[string]any)
|
account.Credentials = make(map[string]any)
|
||||||
}
|
}
|
||||||
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
|
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
|
||||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil {
|
||||||
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
|
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
|
||||||
} else {
|
} else {
|
||||||
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
|
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ type rateLimitAccountRepoStub struct {
|
|||||||
mockAccountRepoForGemini
|
mockAccountRepoForGemini
|
||||||
setErrorCalls int
|
setErrorCalls int
|
||||||
tempCalls int
|
tempCalls int
|
||||||
|
updateCredentialsCalls int
|
||||||
|
lastCredentials map[string]any
|
||||||
lastErrorMsg string
|
lastErrorMsg string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
||||||
|
r.updateCredentialsCalls++
|
||||||
|
r.lastCredentials = cloneCredentials(credentials)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type tokenCacheInvalidatorRecorder struct {
|
type tokenCacheInvalidatorRecorder struct {
|
||||||
accounts []*Account
|
accounts []*Account
|
||||||
err error
|
err error
|
||||||
@@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
|
|||||||
require.True(t, shouldDisable)
|
require.True(t, shouldDisable)
|
||||||
require.Equal(t, 0, repo.setErrorCalls)
|
require.Equal(t, 0, repo.setErrorCalls)
|
||||||
require.Equal(t, 1, repo.tempCalls)
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||||
require.Len(t, invalidator.accounts, 1)
|
require.Len(t, invalidator.accounts, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
|||||||
require.Equal(t, 1, repo.setErrorCalls)
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
require.Empty(t, invalidator.accounts)
|
require.Empty(t, invalidator.accounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 103,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||||
|
require.NotEmpty(t, repo.lastCredentials["expires_at"])
|
||||||
|
}
|
||||||
|
|||||||
@@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c.accountRepo != nil {
|
if c.accountRepo != nil {
|
||||||
if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() {
|
if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() {
|
||||||
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
|||||||
newCredentials, err = refresher.Refresh(ctx, account)
|
newCredentials, err = refresher.Refresh(ctx, account)
|
||||||
if newCredentials != nil {
|
if newCredentials != nil {
|
||||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||||
account.Credentials = newCredentials
|
if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil {
|
||||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
|
||||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ import (
|
|||||||
type tokenRefreshAccountRepo struct {
|
type tokenRefreshAccountRepo struct {
|
||||||
mockAccountRepoForGemini
|
mockAccountRepoForGemini
|
||||||
updateCalls int
|
updateCalls int
|
||||||
|
fullUpdateCalls int
|
||||||
|
updateCredentialsCalls int
|
||||||
setErrorCalls int
|
setErrorCalls int
|
||||||
clearTempCalls int
|
clearTempCalls int
|
||||||
lastAccount *Account
|
lastAccount *Account
|
||||||
@@ -23,10 +25,29 @@ type tokenRefreshAccountRepo struct {
|
|||||||
|
|
||||||
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
|
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
|
||||||
r.updateCalls++
|
r.updateCalls++
|
||||||
|
r.fullUpdateCalls++
|
||||||
r.lastAccount = account
|
r.lastAccount = account
|
||||||
return r.updateErr
|
return r.updateErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
||||||
|
r.updateCalls++
|
||||||
|
r.updateCredentialsCalls++
|
||||||
|
if r.updateErr != nil {
|
||||||
|
return r.updateErr
|
||||||
|
}
|
||||||
|
cloned := cloneCredentials(credentials)
|
||||||
|
if r.accountsByID != nil {
|
||||||
|
if acc, ok := r.accountsByID[id]; ok && acc != nil {
|
||||||
|
acc.Credentials = cloned
|
||||||
|
r.lastAccount = acc
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.lastAccount = &Account{ID: id, Credentials: cloned}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
r.setErrorCalls++
|
r.setErrorCalls++
|
||||||
return nil
|
return nil
|
||||||
@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
|||||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||||
|
require.Equal(t, 0, repo.fullUpdateCalls)
|
||||||
require.Equal(t, 1, invalidator.calls)
|
require.Equal(t, 1, invalidator.calls)
|
||||||
require.Equal(t, "new-token", account.GetCredential("access_token"))
|
require.Equal(t, "new-token", account.GetCredential("access_token"))
|
||||||
}
|
}
|
||||||
@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
|||||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||||
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) {
|
||||||
|
repo := &tokenRefreshAccountRepo{}
|
||||||
|
cfg := &config.Config{
|
||||||
|
TokenRefresh: config.TokenRefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryBackoffSeconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||||
|
resetAt := time.Now().Add(30 * time.Minute)
|
||||||
|
account := &Account{
|
||||||
|
ID: 17,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
RateLimitResetAt: &resetAt,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "old-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
refresher := &tokenRefresherStub{
|
||||||
|
credentials: map[string]any{
|
||||||
|
"access_token": "new-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||||
|
require.Equal(t, 0, repo.fullUpdateCalls)
|
||||||
|
require.NotNil(t, account.RateLimitResetAt)
|
||||||
|
require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
|
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
|
||||||
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||||
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
|
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
|
||||||
|
|||||||
Reference in New Issue
Block a user