fix: restore OAuth 401 temp-unschedulable for Gemini, update Antigravity tests

The 403 detection PR changed the 401 handler condition from
`account.Type == AccountTypeOAuth` to
`account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI`,
which accidentally excluded Gemini OAuth from the temp-unschedulable path.

Fix: use `!= PlatformAntigravity` instead, preserving Gemini behavior
while correctly excluding Antigravity (whose 401 is handled by
applyErrorPolicy's temp_unschedulable_rules).

Update tests to reflect Antigravity's new 401 semantics:
- HandleUpstreamError: Antigravity OAuth 401 now uses SetError
- CheckErrorPolicy: Antigravity 401 second hit stays TempUnscheduled
- DB fallback: split into Gemini (escalates) and Antigravity (stays temp)
This commit is contained in:
erio
2026-03-14 02:21:22 +08:00
parent 6344fa2a86
commit 45456fa24c
4 changed files with 112 additions and 65 deletions

View File

@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
// second hit 仍然返回 TempUnscheduled。
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
account: &Account{
ID: 15,
Type: AccountTypeOAuth,
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyNone,
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_body_miss_returns_none",

View File

@@ -149,9 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
// 其他 400 错误(如参数问题)不处理,不禁用账号
case 401:
// OpenAI OAuth 账号在 401 错误时临时不可调度;其他平台 OAuth 账号保持原有 SetError 行为
// Antigravity 主流程不走此路径,其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制
if account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI {
// OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为
// Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
// 1. 失效缓存
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
@@ -183,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
shouldDisable = true
} else {
// 非 OAuth 账号APIKey:保持原有 SetError 行为
// 非 OAuth / Antigravity OAuth:保持 SetError 行为
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg

View File

@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
// but DB account has a previous 401 record.
// Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
// Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
t.Run("gemini_escalates", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "", // cache miss — reason is empty
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformGemini,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
}
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
})
t.Run("antigravity_stays_temp", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
})
}
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {

View File

@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
}
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: tt.platform,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
t.Run("gemini", func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
}
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
}
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
// HandleUpstreamError 中走 SetError 路径。
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Empty(t, invalidator.accounts)
})
}
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {