diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 3556da88..9a0488a4 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -326,7 +326,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC var lastErr error for urlIdx, baseURL := range availableURLs { apiURL := baseURL + "/v1internal:loadCodeAssist" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes)) if err != nil { lastErr = fmt.Errorf("创建请求失败: %w", err) continue @@ -405,7 +405,7 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s apiURL := baseURL + "/v1internal:onboardUser" for attempt := 1; attempt <= 5; attempt++ { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes)) if err != nil { lastErr = fmt.Errorf("创建请求失败: %w", err) break @@ -456,7 +456,11 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s } // done=false 时等待后重试(与 CLIProxyAPI 行为一致) - time.Sleep(2 * time.Second) + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(2 * time.Second): + } } } @@ -521,7 +525,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI var lastErr error for urlIdx, baseURL := range availableURLs { apiURL := baseURL + "/v1internal:fetchAvailableModels" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes)) if err != nil { lastErr = fmt.Errorf("创建请求失败: %w", err) continue diff --git a/backend/internal/service/antigravity_oauth_service_test.go b/backend/internal/service/antigravity_oauth_service_test.go index e041c2b4..0325d9bc 100644 --- a/backend/internal/service/antigravity_oauth_service_test.go +++ b/backend/internal/service/antigravity_oauth_service_test.go @@ -51,7 +51,6 @@ func TestResolveDefaultTierID(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index e0ada9f1..774c1c75 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -3,16 +3,24 @@ package service import ( "context" "errors" - "log" "log/slog" "strconv" "strings" + "sync" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) const ( antigravityTokenRefreshSkew = 3 * time.Minute antigravityTokenCacheSkew = 5 * time.Minute + + // projectIDFillCooldown 同一账号 project_id 补齐失败后的冷却时间 + projectIDFillCooldown = 60 * time.Second + + // fallbackProjectID 所有获取方式都失败时的兜底值(与 Antigravity-Manager 一致) + fallbackProjectID = "bamboo-precept-lgxtn" ) // AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) @@ -23,6 +31,9 @@ type AntigravityTokenProvider struct { accountRepo AccountRepository tokenCache AntigravityTokenCache antigravityOAuthService *AntigravityOAuthService + + // projectIDFillAttempts 记录每个账号最近一次 project_id 补齐尝试时间,用于冷却去重 + projectIDFillAttempts sync.Map // map[int64]time.Time } func NewAntigravityTokenProvider( @@ -94,14 +105,10 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return "", err } newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } + mergeCredentials(newCredentials, account.Credentials) account.Credentials = newCredentials if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) + slog.Error("failed to update account credentials after token refresh", "account_id", account.ID, "error", updateErr) } expiresAt = account.GetCredentialAsTime("expires_at") } @@ -113,27 +120,12 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return "", errors.New("access_token not found in credentials") } - // 如果账号还没有 project_id,优先尝试在线补齐,避免请求 daily/sandbox 时出现 - // "Invalid project resource name projects/"。 - if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil { - if tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account); err == nil { - newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - log.Printf("[AntigravityTokenProvider] Failed to persist project_id补齐: %v", updateErr) - } - if refreshed := strings.TrimSpace(account.GetCredential("access_token")); refreshed != "" { - accessToken = refreshed - } - } + // 3. 如果缺少 project_id,轻量补齐(不刷新 token) + if strings.TrimSpace(account.GetCredential("project_id")) == "" { + p.tryFillProjectID(ctx, account, accessToken) } - // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) + // 4. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) if isStale && latestAccount != nil { @@ -164,6 +156,70 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return accessToken, nil } +// tryFillProjectID 轻量级 project_id 补齐(与 Antigravity-Manager 保持一致) +// 只调用 loadCodeAssist + onboardUser,不刷新 token。 +// 带冷却去重:同一账号 60s 内不重复尝试。 +func (p *AntigravityTokenProvider) tryFillProjectID(ctx context.Context, account *Account, accessToken string) { + // 冷却检查:60s 内不重复尝试 + if lastAttempt, ok := p.projectIDFillAttempts.Load(account.ID); ok { + if time.Since(lastAttempt.(time.Time)) < projectIDFillCooldown { + return + } + } + p.projectIDFillAttempts.Store(account.ID, time.Now()) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + client := antigravity.NewClient(proxyURL) + + // Step 1: loadCodeAssist(单次调用,不重试) + loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) + if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { + p.persistProjectID(ctx, account, loadResp.CloudAICompanionProject) + p.projectIDFillAttempts.Delete(account.ID) // 成功后清除冷却 + return + } + + // Step 2: onboardUser(loadCodeAssist 成功但未返回 project_id 时) + if err == nil { + if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" { + p.persistProjectID(ctx, account, projectID) + p.projectIDFillAttempts.Delete(account.ID) + return + } + } + + // Step 3: 兜底值(与 Antigravity-Manager 一致) + slog.Warn("project_id fill failed, using fallback", + "account_id", account.ID, + "fallback", fallbackProjectID, + ) + p.persistProjectID(ctx, account, fallbackProjectID) +} + +// persistProjectID 将 project_id 写入账号凭证并持久化 +func (p *AntigravityTokenProvider) persistProjectID(ctx context.Context, account *Account, projectID string) { + account.Credentials["project_id"] = projectID + if p.accountRepo == nil { + return + } + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + slog.Error("failed to persist project_id", "account_id", account.ID, "error", updateErr) + } +} + +// mergeCredentials 将 old 中不存在于 new 的字段合并到 new +func mergeCredentials(newCreds, oldCreds map[string]any) { + for k, v := range oldCreds { + if _, exists := newCreds[k]; !exists { + newCreds[k] = v + } + } +} + func AntigravityTokenCacheKey(account *Account) string { projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" {