mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
fix: restore OAuth 401 temp-unschedulable for Gemini, update Antigravity tests
The 403 detection PR changed the 401 handler condition from `account.Type == AccountTypeOAuth` to `account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI`, which accidentally excluded Gemini OAuth from the temp-unschedulable path. Fix: use `!= PlatformAntigravity` instead, preserving Gemini behavior while correctly excluding Antigravity (whose 401 is handled by applyErrorPolicy's temp_unschedulable_rules). Update tests to reflect Antigravity's new 401 semantics: - HandleUpstreamError: Antigravity OAuth 401 now uses SetError - CheckErrorPolicy: Antigravity 401 second hit stays TempUnscheduled - DB fallback: split into Gemini (escalates) and Antigravity (stays temp)
This commit is contained in:
@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
|
||||
// second hit 仍然返回 TempUnscheduled。
|
||||
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
|
||||
@@ -149,9 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
// 其他 400 错误(如参数问题)不处理,不禁用账号
|
||||
case 401:
|
||||
// OpenAI OAuth 账号在 401 错误时临时不可调度;其他平台 OAuth 账号保持原有 SetError 行为
|
||||
// (Antigravity 主流程不走此路径,其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制)
|
||||
if account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI {
|
||||
// OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。
|
||||
// Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。
|
||||
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
|
||||
// 1. 失效缓存
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
@@ -183,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
shouldDisable = true
|
||||
} else {
|
||||
// 非 OAuth 账号(APIKey):保持原有 SetError 行为
|
||||
// 非 OAuth / Antigravity OAuth:保持 SetError 行为
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Authentication failed (401): " + upstreamMsg
|
||||
|
||||
@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
|
||||
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
|
||||
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
// but DB account has a previous 401 record.
|
||||
// Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
|
||||
// Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
|
||||
t.Run("gemini_escalates", func(t *testing.T) {
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "", // cache miss — reason is empty
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformGemini,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
|
||||
})
|
||||
|
||||
t.Run("antigravity_stays_temp", func(t *testing.T) {
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
|
||||
|
||||
@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
}{
|
||||
{name: "gemini", platform: PlatformGemini},
|
||||
{name: "antigravity", platform: PlatformAntigravity},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: tt.platform,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
t.Run("gemini", func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
}
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
|
||||
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
|
||||
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
|
||||
// HandleUpstreamError 中走 SetError 路径。
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
require.Empty(t, invalidator.accounts)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user