diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index d46bbc45..1e63315b 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -19,6 +19,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) +// ForbiddenError 表示上游返回 403 Forbidden +type ForbiddenError struct { + StatusCode int + Body string +} + +func (e *ForbiddenError) Error() string { + return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body) +} + // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 @@ -514,7 +524,20 @@ type ModelQuotaInfo struct { // ModelInfo 模型信息 type ModelInfo struct { - QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"` + QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"` + DisplayName string `json:"displayName,omitempty"` + SupportsImages *bool `json:"supportsImages,omitempty"` + SupportsThinking *bool `json:"supportsThinking,omitempty"` + ThinkingBudget *int `json:"thinkingBudget,omitempty"` + Recommended *bool `json:"recommended,omitempty"` + MaxTokens *int `json:"maxTokens,omitempty"` + MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` + SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"` +} + +// DeprecatedModelInfo 废弃模型转发信息 +type DeprecatedModelInfo struct { + NewModelID string `json:"newModelId"` } // FetchAvailableModelsRequest fetchAvailableModels 请求 @@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct { // FetchAvailableModelsResponse fetchAvailableModels 响应 type FetchAvailableModelsResponse struct { - Models map[string]ModelInfo `json:"models"` + Models map[string]ModelInfo `json:"models"` + DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"` } // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON @@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI continue } + if resp.StatusCode == http.StatusForbidden { + return nil, nil, &ForbiddenError{ + StatusCode: resp.StatusCode, + Body: string(respBodyBytes), + } + } + if resp.StatusCode != http.StatusOK { return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 3dd931be..d41e890a 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "log" + "log/slog" "math/rand/v2" "net/http" "strings" @@ -100,6 +101,7 @@ type antigravityUsageCache struct { const ( apiCacheTTL = 3 * time.Minute apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟 + antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误) apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟 windowStatsCacheTTL = 1 * time.Minute openAIProbeCacheTTL = 10 * time.Minute @@ -108,11 +110,12 @@ const ( // UsageCache 封装账户使用量相关的缓存 type UsageCache struct { - apiCache sync.Map // accountID -> *apiUsageCache - windowStatsCache sync.Map // accountID -> *windowStatsCache - antigravityCache sync.Map // accountID -> *antigravityUsageCache - apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存 - openAIProbeCache sync.Map // accountID -> time.Time + apiCache sync.Map // accountID -> *apiUsageCache + windowStatsCache sync.Map // accountID -> *windowStatsCache + antigravityCache sync.Map // accountID -> *antigravityUsageCache + apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic) + antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存 + openAIProbeCache sync.Map // accountID -> time.Time } // NewUsageCache 创建 UsageCache 实例 @@ -149,6 +152,18 @@ type AntigravityModelQuota struct { ResetTime string `json:"reset_time"` // 重置时间 ISO8601 } +// AntigravityModelDetail Antigravity 单个模型的详细能力信息 +type AntigravityModelDetail struct { + DisplayName string `json:"display_name,omitempty"` + SupportsImages *bool `json:"supports_images,omitempty"` + SupportsThinking *bool `json:"supports_thinking,omitempty"` + ThinkingBudget *int `json:"thinking_budget,omitempty"` + Recommended *bool `json:"recommended,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"` +} + // UsageInfo 账号使用量信息 type UsageInfo struct { UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 @@ -164,6 +179,33 @@ type UsageInfo struct { // Antigravity 多模型配额 AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` + + // Antigravity 账号级信息 + SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN + SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称 + + // Antigravity 模型详细能力信息(与 antigravity_quota 同 key) + AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"` + + // Antigravity 废弃模型转发规则 (old_model_id -> new_model_id) + ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"` + + // Antigravity 账号是否被上游禁止 (HTTP 403) + IsForbidden bool `json:"is_forbidden,omitempty"` + ForbiddenReason string `json:"forbidden_reason,omitempty"` + ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden" + ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接 + + // 状态标记(从 ForbiddenType / HTTP 错误码推导) + NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation) + IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation) + NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401) + + // 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error + ErrorCode string `json:"error_code,omitempty"` + + // 获取 usage 时的错误信息(降级返回,而非 500) + Error string `json:"error,omitempty"` } // ClaudeUsageResponse Anthropic API返回的usage结构 @@ -648,34 +690,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account * return &UsageInfo{UpdatedAt: &now}, nil } - // 1. 检查缓存(10 分钟) + // 1. 检查缓存 if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { - if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { - // 重新计算 RemainingSeconds - usage := cache.usageInfo - if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { - usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) + if cache, ok := cached.(*antigravityUsageCache); ok { + ttl := antigravityCacheTTL(cache.usageInfo) + if time.Since(cache.timestamp) < ttl { + usage := cache.usageInfo + if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { + usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) + } + return usage, nil } - return usage, nil } } - // 2. 获取代理 URL - proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account) + // 2. singleflight 防止并发击穿 + flightKey := fmt.Sprintf("ag-usage:%d", account.ID) + result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) { + // 再次检查缓存(等待期间可能已被填充) + if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { + if cache, ok := cached.(*antigravityUsageCache); ok { + ttl := antigravityCacheTTL(cache.usageInfo) + if time.Since(cache.timestamp) < ttl { + usage := cache.usageInfo + // 重新计算 RemainingSeconds,避免返回过时的剩余秒数 + recalcAntigravityRemainingSeconds(usage) + return usage, nil + } + } + } - // 3. 调用 API 获取额度 - result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL) - if err != nil { - return nil, fmt.Errorf("fetch antigravity quota failed: %w", err) - } + // 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败 + fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer fetchCancel() - // 4. 缓存结果 - s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ - usageInfo: result.UsageInfo, - timestamp: time.Now(), + proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account) + fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL) + if err != nil { + degraded := buildAntigravityDegradedUsage(err) + enrichUsageWithAccountError(degraded, account) + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: degraded, + timestamp: time.Now(), + }) + return degraded, nil + } + + enrichUsageWithAccountError(fetchResult.UsageInfo, account) + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: fetchResult.UsageInfo, + timestamp: time.Now(), + }) + return fetchResult.UsageInfo, nil }) - return result.UsageInfo, nil + if flightErr != nil { + return nil, flightErr + } + usage, ok := result.(*UsageInfo) + if !ok || usage == nil { + now := time.Now() + return &UsageInfo{UpdatedAt: &now}, nil + } + return usage, nil +} + +// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds +// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数 +func recalcAntigravityRemainingSeconds(info *UsageInfo) { + if info == nil { + return + } + if info.FiveHour != nil && info.FiveHour.ResetsAt != nil { + remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds()) + if remaining < 0 { + remaining = 0 + } + info.FiveHour.RemainingSeconds = remaining + } +} + +// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL +// 403 forbidden 状态稳定,缓存与成功相同(3 分钟); +// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。 +func antigravityCacheTTL(info *UsageInfo) time.Duration { + if info == nil { + return antigravityErrorTTL + } + if info.IsForbidden { + return apiCacheTTL // 封号/验证状态不会很快变 + } + if info.ErrorCode != "" || info.Error != "" { + return antigravityErrorTTL + } + return apiCacheTTL +} + +// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo +func buildAntigravityDegradedUsage(err error) *UsageInfo { + now := time.Now() + errMsg := fmt.Sprintf("usage API error: %v", err) + slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err) + + info := &UsageInfo{ + UpdatedAt: &now, + Error: errMsg, + } + + // 从错误信息推断 error_code 和状态标记 + // 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..." + errStr := err.Error() + switch { + case strings.Contains(errStr, "HTTP 401") || + strings.Contains(errStr, "UNAUTHENTICATED") || + strings.Contains(errStr, "invalid_grant"): + info.ErrorCode = errorCodeUnauthenticated + info.NeedsReauth = true + case strings.Contains(errStr, "HTTP 429"): + info.ErrorCode = errorCodeRateLimited + default: + info.ErrorCode = errorCodeNetworkError + } + + return info +} + +// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo +// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error, +// +// 需要在正常 usage 数据上附加 forbidden/validation 信息。 +// +// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401, +// +// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。 +func enrichUsageWithAccountError(info *UsageInfo, account *Account) { + if info == nil || account == nil || account.Status != StatusError { + return + } + msg := strings.ToLower(account.ErrorMessage) + if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") && + !strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") { + return + } + fbType := classifyForbiddenType(account.ErrorMessage) + info.IsForbidden = true + info.ForbiddenType = fbType + info.ForbiddenReason = account.ErrorMessage + info.NeedsVerify = fbType == forbiddenTypeValidation + info.IsBanned = fbType == forbiddenTypeViolation + info.ValidationURL = extractValidationURL(account.ErrorMessage) + info.ErrorCode = errorCodeForbidden + info.NeedsReauth = false } // addWindowStats 为 usage 数据添加窗口期统计 diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index e950ec1d..f8990b1a 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -2,12 +2,29 @@ package service import ( "context" + "encoding/json" + "errors" "fmt" + "log/slog" + "regexp" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) +const ( + forbiddenTypeValidation = "validation" + forbiddenTypeViolation = "violation" + forbiddenTypeForbidden = "forbidden" + + // 机器可读的错误码 + errorCodeForbidden = "forbidden" + errorCodeUnauthenticated = "unauthenticated" + errorCodeRateLimited = "rate_limited" + errorCodeNetworkError = "network_error" +) + // AntigravityQuotaFetcher 从 Antigravity API 获取额度 type AntigravityQuotaFetcher struct { proxyRepo ProxyRepository @@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou // 调用 API 获取配额 modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) if err != nil { + // 403 Forbidden: 不报错,返回 is_forbidden 标记 + var forbiddenErr *antigravity.ForbiddenError + if errors.As(err, &forbiddenErr) { + now := time.Now() + fbType := classifyForbiddenType(forbiddenErr.Body) + return &QuotaResult{ + UsageInfo: &UsageInfo{ + UpdatedAt: &now, + IsForbidden: true, + ForbiddenReason: forbiddenErr.Body, + ForbiddenType: fbType, + ValidationURL: extractValidationURL(forbiddenErr.Body), + NeedsVerify: fbType == forbiddenTypeValidation, + IsBanned: fbType == forbiddenTypeViolation, + ErrorCode: errorCodeForbidden, + }, + }, nil + } return nil, err } + // 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程) + tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken) + // 转换为 UsageInfo - usageInfo := f.buildUsageInfo(modelsResp) + usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized) return &QuotaResult{ UsageInfo: usageInfo, @@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou }, nil } -// buildUsageInfo 将 API 响应转换为 UsageInfo -func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo { - now := time.Now() - info := &UsageInfo{ - UpdatedAt: &now, - AntigravityQuota: make(map[string]*AntigravityModelQuota), +// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串 +func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) { + loadResp, _, err := client.LoadCodeAssist(ctx, accessToken) + if err != nil { + slog.Warn("failed to fetch subscription tier", "error", err) + return "", "" + } + if loadResp == nil { + return "", "" } - // 遍历所有模型,填充 AntigravityQuota + raw = loadResp.GetTier() // 已有方法:paidTier > currentTier + normalized = normalizeTier(raw) + return raw, normalized +} + +// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN +func normalizeTier(raw string) string { + if raw == "" { + return "" + } + lower := strings.ToLower(raw) + switch { + case strings.Contains(lower, "ultra"): + return "ULTRA" + case strings.Contains(lower, "pro"): + return "PRO" + case strings.Contains(lower, "free"): + return "FREE" + default: + return "UNKNOWN" + } +} + +// buildUsageInfo 将 API 响应转换为 UsageInfo +func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo { + now := time.Now() + info := &UsageInfo{ + UpdatedAt: &now, + AntigravityQuota: make(map[string]*AntigravityModelQuota), + AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail), + SubscriptionTier: tierNormalized, + SubscriptionTierRaw: tierRaw, + } + + // 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails for modelName, modelInfo := range modelsResp.Models { if modelInfo.QuotaInfo == nil { continue @@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv Utilization: utilization, ResetTime: modelInfo.QuotaInfo.ResetTime, } + + // 填充模型详细能力信息 + detail := &AntigravityModelDetail{ + DisplayName: modelInfo.DisplayName, + SupportsImages: modelInfo.SupportsImages, + SupportsThinking: modelInfo.SupportsThinking, + ThinkingBudget: modelInfo.ThinkingBudget, + Recommended: modelInfo.Recommended, + MaxTokens: modelInfo.MaxTokens, + MaxOutputTokens: modelInfo.MaxOutputTokens, + SupportedMimeTypes: modelInfo.SupportedMimeTypes, + } + info.AntigravityQuotaDetails[modelName] = detail + } + + // 废弃模型转发规则 + if len(modelsResp.DeprecatedModelIDs) > 0 { + info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs)) + for oldID, deprecated := range modelsResp.DeprecatedModelIDs { + info.ModelForwardingRules[oldID] = deprecated.NewModelID + } } // 同时设置 FiveHour 用于兼容展示(取主要模型) @@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco } return proxy.URL() } + +// classifyForbiddenType 根据 403 响应体判断禁止类型 +func classifyForbiddenType(body string) string { + lower := strings.ToLower(body) + switch { + case strings.Contains(lower, "validation_required") || + strings.Contains(lower, "verify your account") || + strings.Contains(lower, "validation_url"): + return forbiddenTypeValidation + case strings.Contains(lower, "terms of service") || + strings.Contains(lower, "violation"): + return forbiddenTypeViolation + default: + return forbiddenTypeForbidden + } +} + +// urlPattern 用于从 403 响应体中提取 URL(降级方案) +var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`) + +// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接 +func extractValidationURL(body string) string { + // 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url + var parsed struct { + Error struct { + Details []struct { + Metadata map[string]string `json:"metadata"` + } `json:"details"` + } `json:"error"` + } + if json.Unmarshal([]byte(body), &parsed) == nil { + for _, detail := range parsed.Error.Details { + if u := detail.Metadata["validation_url"]; u != "" { + return u + } + if u := detail.Metadata["appeal_url"]; u != "" { + return u + } + } + } + + // 2. 降级:正则匹配 URL + lower := strings.ToLower(body) + if !strings.Contains(lower, "validation") && + !strings.Contains(lower, "verify") && + !strings.Contains(lower, "appeal") { + return "" + } + // 先解码常见转义再匹配 + normalized := strings.ReplaceAll(body, `\u0026`, "&") + if m := urlPattern.FindString(normalized); m != "" { + return m + } + return "" +} diff --git a/backend/internal/service/antigravity_quota_fetcher_test.go b/backend/internal/service/antigravity_quota_fetcher_test.go new file mode 100644 index 00000000..5ead8e60 --- /dev/null +++ b/backend/internal/service/antigravity_quota_fetcher_test.go @@ -0,0 +1,497 @@ +//go:build unit + +package service + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +// --------------------------------------------------------------------------- +// normalizeTier +// --------------------------------------------------------------------------- + +func TestNormalizeTier(t *testing.T) { + tests := []struct { + name string + raw string + expected string + }{ + {name: "empty string", raw: "", expected: ""}, + {name: "free-tier", raw: "free-tier", expected: "FREE"}, + {name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"}, + {name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"}, + {name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"}, + {name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"}, + {name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"}, + {name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"}, + {name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeTier(tt.raw) + require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw) + }) + } +} + +// --------------------------------------------------------------------------- +// buildUsageInfo +// --------------------------------------------------------------------------- + +func aqfBoolPtr(v bool) *bool { return &v } +func aqfIntPtr(v int) *int { return &v } + +func TestBuildUsageInfo_BasicModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.75, + ResetTime: "2026-03-08T12:00:00Z", + }, + DisplayName: "Claude Sonnet 4", + SupportsImages: aqfBoolPtr(true), + SupportsThinking: aqfBoolPtr(false), + ThinkingBudget: aqfIntPtr(0), + Recommended: aqfBoolPtr(true), + MaxTokens: aqfIntPtr(200000), + MaxOutputTokens: aqfIntPtr(16384), + SupportedMimeTypes: map[string]bool{ + "image/png": true, + "image/jpeg": true, + }, + }, + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + ResetTime: "2026-03-08T15:00:00Z", + }, + DisplayName: "Gemini 2.5 Pro", + MaxTokens: aqfIntPtr(1000000), + MaxOutputTokens: aqfIntPtr(65536), + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO") + + // 基本字段 + require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set") + require.Equal(t, "PRO", info.SubscriptionTier) + require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw) + + // AntigravityQuota + require.Len(t, info.AntigravityQuota, 2) + + sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, sonnetQuota) + require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25 + require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime) + + geminiQuota := info.AntigravityQuota["gemini-2.5-pro"] + require.NotNil(t, geminiQuota) + require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50 + require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime) + + // AntigravityQuotaDetails + require.Len(t, info.AntigravityQuotaDetails, 2) + + sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"] + require.NotNil(t, sonnetDetail) + require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName) + require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages) + require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking) + require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget) + require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended) + require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens) + require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens) + require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes) + + geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"] + require.NotNil(t, geminiDetail) + require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName) + require.Nil(t, geminiDetail.SupportsImages) + require.Nil(t, geminiDetail.SupportsThinking) + require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens) + require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens) +} + +func TestBuildUsageInfo_DeprecatedModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 1.0, + }, + }, + }, + DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{ + "claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"}, + "claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"}, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.Len(t, info.ModelForwardingRules, 2) + require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"]) + require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"]) +} + +func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "some-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9}, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models") +} + +func TestBuildUsageInfo_EmptyModels(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{}, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info) + require.NotNil(t, info.AntigravityQuota) + require.Empty(t, info.AntigravityQuota) + require.NotNil(t, info.AntigravityQuotaDetails) + require.Empty(t, info.AntigravityQuotaDetails) + require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists") +} + +func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "model-without-quota": { + DisplayName: "No Quota Model", + // QuotaInfo is nil + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info) + require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped") + require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too") +} + +func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"] + // When the first priority model exists, it should be used for FiveHour + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.40, + ResetTime: "2026-03-08T18:00:00Z", + }, + }, + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.80, + ResetTime: "2026-03-08T12:00:00Z", + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists") + // claude-sonnet-4-20250514 is first in priority list, so it should be used + expectedUtilization := (1.0 - 0.80) * 100 // 20 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) + require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime") +} + +func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514 + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.60, + ResetTime: "2026-03-08T14:00:00Z", + }, + }, + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.30, + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info.FiveHour) + expectedUtilization := (1.0 - 0.60) * 100 // 40 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) +} + +func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // Only gemini-2.5-pro exists (third in priority list) + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "gemini-2.5-pro": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.30, + }, + }, + "other-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.90, + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info.FiveHour) + expectedUtilization := (1.0 - 0.30) * 100 // 70 + require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01) +} + +func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + // None of the priority models exist + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "some-other-model": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists") +} + +func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.50, + ResetTime: "", // empty reset time + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + require.NotNil(t, info.FiveHour) + require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty") + require.Equal(t, 0, info.FiveHour.RemainingSeconds) +} + +func TestBuildUsageInfo_FullUtilization(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 0.0, // fully used + ResetTime: "2026-03-08T12:00:00Z", + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + quota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, quota) + require.Equal(t, 100, quota.Utilization) +} + +func TestBuildUsageInfo_ZeroUtilization(t *testing.T) { + fetcher := &AntigravityQuotaFetcher{} + + modelsResp := &antigravity.FetchAvailableModelsResponse{ + Models: map[string]antigravity.ModelInfo{ + "claude-sonnet-4-20250514": { + QuotaInfo: &antigravity.ModelQuotaInfo{ + RemainingFraction: 1.0, // fully available + }, + }, + }, + } + + info := fetcher.buildUsageInfo(modelsResp, "", "") + + quota := info.AntigravityQuota["claude-sonnet-4-20250514"] + require.NotNil(t, quota) + require.Equal(t, 0, quota.Utilization) +} + +func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) { + // 模拟 FetchQuota 遇到 403 时的行为: + // FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true + forbiddenErr := &antigravity.ForbiddenError{ + StatusCode: 403, + Body: "Access denied", + } + + // 验证 ForbiddenError 满足 errors.As + var target *antigravity.ForbiddenError + require.True(t, errors.As(forbiddenErr, &target)) + require.Equal(t, 403, target.StatusCode) + require.Equal(t, "Access denied", target.Body) + require.Contains(t, forbiddenErr.Error(), "403") +} + +// --------------------------------------------------------------------------- +// classifyForbiddenType +// --------------------------------------------------------------------------- + +func TestClassifyForbiddenType(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "VALIDATION_REQUIRED keyword", + body: `{"error":{"message":"VALIDATION_REQUIRED"}}`, + expected: "validation", + }, + { + name: "verify your account", + body: `Please verify your account to continue`, + expected: "validation", + }, + { + name: "contains validation_url field", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`, + expected: "validation", + }, + { + name: "terms of service violation", + body: `Your account has been suspended for Terms of Service violation`, + expected: "violation", + }, + { + name: "violation keyword", + body: `Account suspended due to policy violation`, + expected: "violation", + }, + { + name: "generic 403", + body: `Access denied`, + expected: "forbidden", + }, + { + name: "empty body", + body: "", + expected: "forbidden", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyForbiddenType(tt.body) + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// extractValidationURL +// --------------------------------------------------------------------------- + +func TestExtractValidationURL(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "structured validation_url", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`, + expected: "https://accounts.google.com/verify?token=abc", + }, + { + name: "structured appeal_url", + body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`, + expected: "https://support.google.com/appeal/123", + }, + { + name: "validation_url takes priority over appeal_url", + body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`, + expected: "https://v.com", + }, + { + name: "fallback regex with verify keyword", + body: `Please verify your account at https://accounts.google.com/verify`, + expected: "https://accounts.google.com/verify", + }, + { + name: "no URL in generic forbidden", + body: `Access denied`, + expected: "", + }, + { + name: "empty body", + body: "", + expected: "", + }, + { + name: "URL present but no validation keywords", + body: `Error at https://example.com/something`, + expected: "", + }, + { + name: "unicode escaped ampersand", + body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`, + expected: "https://accounts.google.com/verify?a=1&b=2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractValidationURL(tt.body) + require.Equal(t, tt.expected, got) + }) + } +} diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go index dd9850bd..297a954c 100644 --- a/backend/internal/service/error_policy_test.go +++ b/backend/internal/service/error_policy_test.go @@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) { expected: ErrorPolicyTempUnscheduled, }, { - name: "temp_unschedulable_401_second_hit_upgrades_to_none", + // Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制), + // second hit 仍然返回 TempUnscheduled。 + name: "temp_unschedulable_401_second_hit_antigravity_stays_temp", account: &Account{ ID: 15, Type: AccountTypeOAuth, @@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) { }, statusCode: 401, body: []byte(`unauthorized`), - expected: ErrorPolicyNone, + expected: ErrorPolicyTempUnscheduled, }, { name: "temp_unschedulable_body_miss_returns_none", diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 5df2d639..d410555d 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -149,8 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: - // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 - if account.Type == AccountTypeOAuth { + // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。 + // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。 + if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity { // 1. 失效缓存 if s.tokenCacheInvalidator != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { @@ -182,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } shouldDisable = true } else { - // 非 OAuth 账号(APIKey):保持原有 SetError 行为 + // 非 OAuth / Antigravity OAuth:保持 SetError 行为 msg := "Authentication failed (401): invalid or expired credentials" if upstreamMsg != "" { msg = "Authentication failed (401): " + upstreamMsg @@ -199,11 +200,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc s.handleAuthError(ctx, account, msg) shouldDisable = true case 403: - // 禁止访问:停止调度,记录错误 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg - } logger.LegacyPrintf( "service.ratelimit", "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s", @@ -215,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc upstreamMsg, truncateForLog(responseBody, 1024), ) - s.handleAuthError(ctx, account, msg) - shouldDisable = true + shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody) case 429: s.handle429(ctx, account, headers, responseBody) shouldDisable = false @@ -621,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } +// handle403 处理 403 Forbidden 错误 +// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用; +// 其他平台保持原有 SetError 行为。 +func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + if account.Platform == PlatformAntigravity { + return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody) + } + // 非 Antigravity 平台:保持原有行为 + msg := "Access forbidden (403): account may be suspended or lack permissions" + if upstreamMsg != "" { + msg = "Access forbidden (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true +} + +// handleAntigravity403 处理 Antigravity 平台的 403 错误 +// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复) +// violation(违规封号)→ 永久 SetError(需人工处理) +// generic(通用禁止)→ 永久 SetError +func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + fbType := classifyForbiddenType(string(responseBody)) + + switch fbType { + case forbiddenTypeValidation: + // VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复 + msg := "Validation required (403): account needs Google verification" + if upstreamMsg != "" { + msg = "Validation required (403): " + upstreamMsg + } + if validationURL := extractValidationURL(string(responseBody)); validationURL != "" { + msg += " | validation_url: " + validationURL + } + s.handleAuthError(ctx, account, msg) + return true + + case forbiddenTypeViolation: + // 违规封号: 永久禁用,需人工处理 + msg := "Account violation (403): terms of service violation" + if upstreamMsg != "" { + msg = "Account violation (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true + + default: + // 通用 403: 保持原有行为 + msg := "Access forbidden (403): account may be suspended or lack permissions" + if upstreamMsg != "" { + msg = "Access forbidden (403): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + return true + } +} + // handleCustomErrorCode 处理自定义错误码,停止账号调度 func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg @@ -1213,7 +1264,8 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac } // 401 首次命中可临时不可调度(给 token 刷新窗口); // 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。 - if statusCode == http.StatusUnauthorized { + // Antigravity 跳过:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制,无需升级逻辑。 + if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity { reason := account.TempUnschedulableReason // 缓存可能没有 reason,从 DB 回退读取 if reason == "" { diff --git a/backend/internal/service/ratelimit_service_401_db_fallback_test.go b/backend/internal/service/ratelimit_service_401_db_fallback_test.go index e1611425..d245b5d5 100644 --- a/backend/internal/service/ratelimit_service_401_db_fallback_test.go +++ b/backend/internal/service/ratelimit_service_401_db_fallback_test.go @@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { // Scenario: cache account has empty TempUnschedulableReason (cache miss), - // but DB account has a previous 401 record → should escalate to ErrorPolicyNone. - repo := &dbFallbackRepoStub{ - dbAccount: &Account{ - ID: 20, - TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, - }, - } - svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + // but DB account has a previous 401 record. + // Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error). + // Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules). + t.Run("gemini_escalates", func(t *testing.T) { + repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 20, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - account := &Account{ - ID: 20, - Type: AccountTypeOAuth, - Platform: PlatformAntigravity, - TempUnschedulableReason: "", // cache miss — reason is empty - Credentials: map[string]any{ - "temp_unschedulable_enabled": true, - "temp_unschedulable_rules": []any{ - map[string]any{ - "error_code": float64(401), - "keywords": []any{"unauthorized"}, - "duration_minutes": float64(10), + account := &Account{ + ID: 20, + Type: AccountTypeOAuth, + Platform: PlatformGemini, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, }, }, - }, - } + } - result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) - require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone") + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate") + }) + + t.Run("antigravity_stays_temp", func(t *testing.T) { + repo := &dbFallbackRepoStub{ + dbAccount: &Account{ + ID: 20, + TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, + }, + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 20, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + TempUnschedulableReason: "", + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(401), + "keywords": []any{"unauthorized"}, + "duration_minutes": float64(10), + }, + }, + }, + } + + result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) + require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled") + }) } func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) { diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 7bced46f..4a6e5d6c 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc } func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { - tests := []struct { - name string - platform string - }{ - {name: "gemini", platform: PlatformGemini}, - {name: "antigravity", platform: PlatformAntigravity}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - invalidator := &tokenCacheInvalidatorRecorder{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - service.SetTokenCacheInvalidator(invalidator) - account := &Account{ - ID: 100, - Platform: tt.platform, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "temp_unschedulable_enabled": true, - "temp_unschedulable_rules": []any{ - map[string]any{ - "error_code": 401, - "keywords": []any{"unauthorized"}, - "duration_minutes": 30, - "description": "custom rule", - }, + t.Run("gemini", func(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 100, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": 401, + "keywords": []any{"unauthorized"}, + "duration_minutes": 30, + "description": "custom rule", }, }, - } + }, + } - shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) - require.True(t, shouldDisable) - require.Equal(t, 0, repo.setErrorCalls) - require.Equal(t, 1, repo.tempCalls) - require.Len(t, invalidator.accounts, 1) - }) - } + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) + require.Len(t, invalidator.accounts, 1) + }) + + t.Run("antigravity_401_uses_SetError", func(t *testing.T) { + // Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制, + // HandleUpstreamError 中走 SetError 路径。 + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 100, + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Empty(t, invalidator.accounts) + }) } func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index e83eaead..883af59a 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -36,6 +36,10 @@