Files
sub2api/backend/internal/service/oauth_refresh_api_test.go

396 lines
13 KiB
Go
Raw Permalink Normal View History

//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------- mock helpers ----------
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type refreshAPIAccountRepo struct {
mockAccountRepoForGemini
account *Account // returned by GetByID
getByIDErr error
updateErr error
updateCalls int
}
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
if r.getByIDErr != nil {
return nil, r.getByIDErr
}
return r.account, nil
}
func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
r.updateCalls++
return r.updateErr
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type refreshAPIExecutorStub struct {
needsRefresh bool
credentials map[string]any
err error
refreshCalls int
}
func (e *refreshAPIExecutorStub) CanRefresh(_ *Account) bool { return true }
func (e *refreshAPIExecutorStub) NeedsRefresh(_ *Account, _ time.Duration) bool {
return e.needsRefresh
}
func (e *refreshAPIExecutorStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
e.refreshCalls++
if e.err != nil {
return nil, e.err
}
return e.credentials, nil
}
func (e *refreshAPIExecutorStub) CacheKey(account *Account) string {
return "test:api:" + account.Platform
}
// refreshAPICacheStub implements GeminiTokenCache for OAuthRefreshAPI tests.
type refreshAPICacheStub struct {
lockResult bool
lockErr error
releaseCalls int
}
func (c *refreshAPICacheStub) GetAccessToken(context.Context, string) (string, error) {
return "", nil
}
func (c *refreshAPICacheStub) SetAccessToken(context.Context, string, string, time.Duration) error {
return nil
}
func (c *refreshAPICacheStub) DeleteAccessToken(context.Context, string) error { return nil }
func (c *refreshAPICacheStub) AcquireRefreshLock(context.Context, string, time.Duration) (bool, error) {
return c.lockResult, c.lockErr
}
func (c *refreshAPICacheStub) ReleaseRefreshLock(context.Context, string) error {
c.releaseCalls++
return nil
}
// ========== RefreshIfNeeded tests ==========
func TestRefreshIfNeeded_Success(t *testing.T) {
account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "new-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.NotNil(t, result.NewCredentials)
require.Equal(t, "new-token", result.NewCredentials["access_token"])
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
require.Equal(t, 1, repo.updateCalls) // DB updated
require.Equal(t, 1, cache.releaseCalls) // lock released
require.Equal(t, 1, executor.refreshCalls)
}
func TestRefreshIfNeeded_LockHeld(t *testing.T) {
account := &Account{ID: 2, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: false} // lock not acquired
executor := &refreshAPIExecutorStub{needsRefresh: true}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.LockHeld)
require.False(t, result.Refreshed)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, executor.refreshCalls)
}
func TestRefreshIfNeeded_LockErrorDegrades(t *testing.T) {
account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockErr: errors.New("redis down")} // lock error
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "degraded-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed) // still refreshed (degraded mode)
require.Equal(t, 1, repo.updateCalls) // DB updated
require.Equal(t, 0, cache.releaseCalls) // no lock to release
require.Equal(t, 1, executor.refreshCalls)
}
func TestRefreshIfNeeded_NoCacheNoLock(t *testing.T) {
account := &Account{ID: 4, Platform: PlatformGemini, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{account: account}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "no-cache-token"},
}
api := NewOAuthRefreshAPI(repo, nil) // no cache = no lock
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Equal(t, 1, repo.updateCalls)
}
func TestRefreshIfNeeded_AlreadyRefreshed(t *testing.T) {
account := &Account{ID: 5, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{needsRefresh: false} // already refreshed
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.False(t, result.Refreshed)
require.False(t, result.LockHeld)
require.NotNil(t, result.Account) // returns fresh account
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, executor.refreshCalls)
}
func TestRefreshIfNeeded_RefreshError(t *testing.T) {
account := &Account{ID: 6, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
err: errors.New("invalid_grant: token revoked"),
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "invalid_grant")
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
}
func TestRefreshIfNeeded_DBUpdateError(t *testing.T) {
account := &Account{ID: 7, Platform: PlatformGemini, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{
account: account,
updateErr: errors.New("db connection lost"),
}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "DB update failed")
require.Equal(t, 1, repo.updateCalls) // attempted
}
func TestRefreshIfNeeded_DBRereadFails(t *testing.T) {
account := &Account{ID: 8, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{
account: nil, // GetByID returns nil
getByIDErr: errors.New("db timeout"),
}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "fallback-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Equal(t, 1, executor.refreshCalls) // still refreshes using passed-in account
}
func TestRefreshIfNeeded_NilCredentials(t *testing.T) {
account := &Account{ID: 9, Platform: PlatformGemini, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: nil, // Refresh returns nil credentials
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Nil(t, result.NewCredentials)
require.Equal(t, 0, repo.updateCalls) // no DB update when credentials are nil
}
// ========== MergeCredentials tests ==========
func TestMergeCredentials_Basic(t *testing.T) {
old := map[string]any{"a": "1", "b": "2", "c": "3"}
new := map[string]any{"a": "new", "d": "4"}
result := MergeCredentials(old, new)
require.Equal(t, "new", result["a"]) // new value preserved
require.Equal(t, "2", result["b"]) // old value kept
require.Equal(t, "3", result["c"]) // old value kept
require.Equal(t, "4", result["d"]) // new value preserved
}
func TestMergeCredentials_NilNew(t *testing.T) {
old := map[string]any{"a": "1"}
result := MergeCredentials(old, nil)
require.NotNil(t, result)
require.Equal(t, "1", result["a"])
}
func TestMergeCredentials_NilOld(t *testing.T) {
new := map[string]any{"a": "1"}
result := MergeCredentials(nil, new)
require.Equal(t, "1", result["a"])
}
func TestMergeCredentials_BothNil(t *testing.T) {
result := MergeCredentials(nil, nil)
require.NotNil(t, result)
require.Empty(t, result)
}
func TestMergeCredentials_NewOverridesOld(t *testing.T) {
old := map[string]any{"access_token": "old-token", "refresh_token": "old-refresh"}
new := map[string]any{"access_token": "new-token"}
result := MergeCredentials(old, new)
require.Equal(t, "new-token", result["access_token"]) // overridden
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
}
// ========== BuildClaudeAccountCredentials tests ==========
func TestBuildClaudeAccountCredentials_Full(t *testing.T) {
tokenInfo := &TokenInfo{
AccessToken: "at-123",
TokenType: "Bearer",
ExpiresIn: 3600,
ExpiresAt: 1700000000,
RefreshToken: "rt-456",
Scope: "openid",
}
creds := BuildClaudeAccountCredentials(tokenInfo)
require.Equal(t, "at-123", creds["access_token"])
require.Equal(t, "Bearer", creds["token_type"])
require.Equal(t, "3600", creds["expires_in"])
require.Equal(t, "1700000000", creds["expires_at"])
require.Equal(t, "rt-456", creds["refresh_token"])
require.Equal(t, "openid", creds["scope"])
}
func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) {
tokenInfo := &TokenInfo{
AccessToken: "at-789",
TokenType: "Bearer",
ExpiresIn: 7200,
ExpiresAt: 1700003600,
}
creds := BuildClaudeAccountCredentials(tokenInfo)
require.Equal(t, "at-789", creds["access_token"])
require.Equal(t, "Bearer", creds["token_type"])
require.Equal(t, "7200", creds["expires_in"])
require.Equal(t, "1700003600", creds["expires_at"])
_, hasRefresh := creds["refresh_token"]
_, hasScope := creds["scope"]
require.False(t, hasRefresh, "refresh_token should not be set when empty")
require.False(t, hasScope, "scope should not be set when empty")
}
// ========== BackgroundRefreshPolicy tests ==========
func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) {
p := DefaultBackgroundRefreshPolicy()
require.ErrorIs(t, p.handleLockHeld(), errRefreshSkipped)
require.ErrorIs(t, p.handleAlreadyRefreshed(), errRefreshSkipped)
}
func TestBackgroundRefreshPolicy_SuccessOverride(t *testing.T) {
p := BackgroundRefreshPolicy{
OnLockHeld: BackgroundSkipAsSuccess,
OnAlreadyRefresh: BackgroundSkipAsSuccess,
}
require.NoError(t, p.handleLockHeld())
require.NoError(t, p.handleAlreadyRefreshed())
}
// ========== ProviderRefreshPolicy tests ==========
func TestClaudeProviderRefreshPolicy(t *testing.T) {
p := ClaudeProviderRefreshPolicy()
require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError)
require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld)
require.Equal(t, time.Minute, p.FailureTTL)
}
func TestOpenAIProviderRefreshPolicy(t *testing.T) {
p := OpenAIProviderRefreshPolicy()
require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError)
require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld)
require.Equal(t, time.Minute, p.FailureTTL)
}
func TestGeminiProviderRefreshPolicy(t *testing.T) {
p := GeminiProviderRefreshPolicy()
require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError)
require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld)
require.Equal(t, time.Duration(0), p.FailureTTL)
}
func TestAntigravityProviderRefreshPolicy(t *testing.T) {
p := AntigravityProviderRefreshPolicy()
require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError)
require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld)
require.Equal(t, time.Duration(0), p.FailureTTL)
}