mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-08 09:10:20 +08:00
The 403 detection PR changed the 401 handler condition from `account.Type == AccountTypeOAuth` to `account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI`. That change accidentally excluded Gemini OAuth accounts from the temp-unschedulable path. Fix: use `account.Platform != PlatformAntigravity` instead. This preserves the Gemini behavior while still excluding Antigravity, whose 401 is handled by the temp_unschedulable_rules in applyErrorPolicy. Update the tests to reflect Antigravity's new 401 semantics:
- HandleUpstreamError: an Antigravity OAuth 401 now uses SetError.
- CheckErrorPolicy: a second Antigravity 401 hit stays TempUnscheduled.
- DB fallback: split into Gemini (escalates) and Antigravity (stays temp).
133 lines
3.9 KiB
Go
133 lines
3.9 KiB
Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type rateLimitAccountRepoStub struct {
|
|
mockAccountRepoForGemini
|
|
setErrorCalls int
|
|
tempCalls int
|
|
lastErrorMsg string
|
|
}
|
|
|
|
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
|
r.setErrorCalls++
|
|
r.lastErrorMsg = errorMsg
|
|
return nil
|
|
}
|
|
|
|
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
|
r.tempCalls++
|
|
return nil
|
|
}
|
|
|
|
type tokenCacheInvalidatorRecorder struct {
|
|
accounts []*Account
|
|
err error
|
|
}
|
|
|
|
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
|
r.accounts = append(r.accounts, account)
|
|
return r.err
|
|
}
|
|
|
|
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
|
|
t.Run("gemini", func(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 100,
|
|
Platform: PlatformGemini,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": 401,
|
|
"keywords": []any{"unauthorized"},
|
|
"duration_minutes": 30,
|
|
"description": "custom rule",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 0, repo.setErrorCalls)
|
|
require.Equal(t, 1, repo.tempCalls)
|
|
require.Len(t, invalidator.accounts, 1)
|
|
})
|
|
|
|
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
|
|
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
|
|
// HandleUpstreamError 中走 SetError 路径。
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 100,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeOAuth,
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 1, repo.setErrorCalls)
|
|
require.Equal(t, 0, repo.tempCalls)
|
|
require.Empty(t, invalidator.accounts)
|
|
})
|
|
}
|
|
|
|
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 101,
|
|
Platform: PlatformGemini,
|
|
Type: AccountTypeOAuth,
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 0, repo.setErrorCalls)
|
|
require.Equal(t, 1, repo.tempCalls)
|
|
require.Len(t, invalidator.accounts, 1)
|
|
}
|
|
|
|
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 102,
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 1, repo.setErrorCalls)
|
|
require.Empty(t, invalidator.accounts)
|
|
}
|