mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 15:02:13 +08:00
fix: 补齐 Antigravity OAuth 账号 project_id 获取逻辑
部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject, 导致转发时 project 字段为空,上游返回 400 "Invalid project resource name projects/"。 - 新增 OnboardUser API:当 loadCodeAssist 未返回 project_id 时, 通过 onboardUser 完成账号初始化并获取 project_id - token 刷新时增加 onboard 兜底逻辑 - GetAccessToken 按需补齐:转发时发现 project_id 为空立即触发刷新 - 新增 resolveDefaultTierID 单元测试
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -135,4 +135,6 @@ docs/*
|
||||
# ===================
|
||||
# 压测工具
|
||||
# ===================
|
||||
tools/loadtest/
|
||||
tools/loadtest/
|
||||
# Antigravity Manager
|
||||
Antigravity-Manager/
|
||||
|
||||
323
antigravity_projectid_fix.patch
Normal file
323
antigravity_projectid_fix.patch
Normal file
@@ -0,0 +1,323 @@
|
||||
diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go
|
||||
index a6279b11..3556da88 100644
|
||||
--- a/backend/internal/pkg/antigravity/client.go
|
||||
+++ b/backend/internal/pkg/antigravity/client.go
|
||||
@@ -115,6 +115,23 @@ type LoadCodeAssistResponse struct {
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
+// OnboardUserRequest onboardUser 请求
|
||||
+type OnboardUserRequest struct {
|
||||
+ TierID string `json:"tierId"`
|
||||
+ Metadata struct {
|
||||
+ IDEType string `json:"ideType"`
|
||||
+ Platform string `json:"platform,omitempty"`
|
||||
+ PluginType string `json:"pluginType,omitempty"`
|
||||
+ } `json:"metadata"`
|
||||
+}
|
||||
+
|
||||
+// OnboardUserResponse onboardUser 响应
|
||||
+type OnboardUserResponse struct {
|
||||
+ Name string `json:"name,omitempty"`
|
||||
+ Done bool `json:"done,omitempty"`
|
||||
+ Response map[string]any `json:"response,omitempty"`
|
||||
+}
|
||||
+
|
||||
// GetTier 获取账户类型
|
||||
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
|
||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
@@ -361,6 +378,113 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
+// OnboardUser 触发账号 onboarding,并返回 project_id
|
||||
+// 说明:
|
||||
+// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject;
|
||||
+// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。
|
||||
+func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||
+ tierID = strings.TrimSpace(tierID)
|
||||
+ if tierID == "" {
|
||||
+ return "", fmt.Errorf("tier_id 为空")
|
||||
+ }
|
||||
+
|
||||
+ reqBody := OnboardUserRequest{TierID: tierID}
|
||||
+ reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
+ reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED"
|
||||
+ reqBody.Metadata.PluginType = "GEMINI"
|
||||
+
|
||||
+ bodyBytes, err := json.Marshal(reqBody)
|
||||
+ if err != nil {
|
||||
+ return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
+ }
|
||||
+
|
||||
+ availableURLs := BaseURLs
|
||||
+ var lastErr error
|
||||
+
|
||||
+ for urlIdx, baseURL := range availableURLs {
|
||||
+ apiURL := baseURL + "/v1internal:onboardUser"
|
||||
+
|
||||
+ for attempt := 1; attempt <= 5; attempt++ {
|
||||
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
+ if err != nil {
|
||||
+ lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
+ break
|
||||
+ }
|
||||
+ req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
+ req.Header.Set("Content-Type", "application/json")
|
||||
+ req.Header.Set("User-Agent", UserAgent)
|
||||
+
|
||||
+ resp, err := c.httpClient.Do(req)
|
||||
+ if err != nil {
|
||||
+ lastErr = fmt.Errorf("onboardUser 请求失败: %w", err)
|
||||
+ if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
+ log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
+ break
|
||||
+ }
|
||||
+ return "", lastErr
|
||||
+ }
|
||||
+
|
||||
+ respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
+ _ = resp.Body.Close()
|
||||
+ if err != nil {
|
||||
+ return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
+ }
|
||||
+
|
||||
+ if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
+ log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
+ break
|
||||
+ }
|
||||
+
|
||||
+ if resp.StatusCode != http.StatusOK {
|
||||
+ lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
+ return "", lastErr
|
||||
+ }
|
||||
+
|
||||
+ var onboardResp OnboardUserResponse
|
||||
+ if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil {
|
||||
+ lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err)
|
||||
+ return "", lastErr
|
||||
+ }
|
||||
+
|
||||
+ if onboardResp.Done {
|
||||
+ if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" {
|
||||
+ DefaultURLAvailability.MarkSuccess(baseURL)
|
||||
+ return projectID, nil
|
||||
+ }
|
||||
+ lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id")
|
||||
+ return "", lastErr
|
||||
+ }
|
||||
+
|
||||
+ // done=false 时等待后重试(与 CLIProxyAPI 行为一致)
|
||||
+ time.Sleep(2 * time.Second)
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ if lastErr != nil {
|
||||
+ return "", lastErr
|
||||
+ }
|
||||
+ return "", fmt.Errorf("onboardUser 未返回 project_id")
|
||||
+}
|
||||
+
|
||||
+func extractProjectIDFromOnboardResponse(resp map[string]any) string {
|
||||
+ if len(resp) == 0 {
|
||||
+ return ""
|
||||
+ }
|
||||
+
|
||||
+ if v, ok := resp["cloudaicompanionProject"]; ok {
|
||||
+ switch project := v.(type) {
|
||||
+ case string:
|
||||
+ return strings.TrimSpace(project)
|
||||
+ case map[string]any:
|
||||
+ if id, ok := project["id"].(string); ok {
|
||||
+ return strings.TrimSpace(id)
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ return ""
|
||||
+}
|
||||
+
|
||||
// ModelQuotaInfo 模型配额信息
|
||||
type ModelQuotaInfo struct {
|
||||
RemainingFraction float64 `json:"remainingFraction"`
|
||||
diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go
|
||||
index fa8379ed..86b7cc2e 100644
|
||||
--- a/backend/internal/service/antigravity_oauth_service.go
|
||||
+++ b/backend/internal/service/antigravity_oauth_service.go
|
||||
@@ -273,12 +273,20 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
- loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
+ loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
+ if err == nil {
|
||||
+ if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
|
||||
+ return projectID, nil
|
||||
+ } else if onboardErr != nil {
|
||||
+ lastErr = onboardErr
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
@@ -292,6 +300,53 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
+func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
|
||||
+ tierID := resolveDefaultTierID(loadRaw)
|
||||
+ if tierID == "" {
|
||||
+ return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
|
||||
+ }
|
||||
+
|
||||
+ projectID, err := client.OnboardUser(ctx, accessToken, tierID)
|
||||
+ if err != nil {
|
||||
+ return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
|
||||
+ }
|
||||
+ return projectID, nil
|
||||
+}
|
||||
+
|
||||
+func resolveDefaultTierID(loadRaw map[string]any) string {
|
||||
+ if len(loadRaw) == 0 {
|
||||
+ return ""
|
||||
+ }
|
||||
+
|
||||
+ rawTiers, ok := loadRaw["allowedTiers"]
|
||||
+ if !ok {
|
||||
+ return ""
|
||||
+ }
|
||||
+
|
||||
+ tiers, ok := rawTiers.([]any)
|
||||
+ if !ok {
|
||||
+ return ""
|
||||
+ }
|
||||
+
|
||||
+ for _, rawTier := range tiers {
|
||||
+ tier, ok := rawTier.(map[string]any)
|
||||
+ if !ok {
|
||||
+ continue
|
||||
+ }
|
||||
+ if isDefault, _ := tier["isDefault"].(bool); !isDefault {
|
||||
+ continue
|
||||
+ }
|
||||
+ if id, ok := tier["id"].(string); ok {
|
||||
+ id = strings.TrimSpace(id)
|
||||
+ if id != "" {
|
||||
+ return id
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ return ""
|
||||
+}
|
||||
+
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go
|
||||
index 94eca94d..dde3bb07 100644
|
||||
--- a/backend/internal/service/antigravity_token_provider.go
|
||||
+++ b/backend/internal/service/antigravity_token_provider.go
|
||||
@@ -102,6 +102,26 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
+ // 如果账号还没有 project_id,优先尝试在线补齐,避免请求 daily/sandbox 时出现
|
||||
+ // "Invalid project resource name projects/"。
|
||||
+ if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||
+ if tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account); err == nil {
|
||||
+ newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
+ for k, v := range account.Credentials {
|
||||
+ if _, exists := newCredentials[k]; !exists {
|
||||
+ newCredentials[k] = v
|
||||
+ }
|
||||
+ }
|
||||
+ account.Credentials = newCredentials
|
||||
+ if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
+ log.Printf("[AntigravityTokenProvider] Failed to persist project_id补齐: %v", updateErr)
|
||||
+ }
|
||||
+ if refreshed := strings.TrimSpace(account.GetCredential("access_token")); refreshed != "" {
|
||||
+ accessToken = refreshed
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
diff --git a/backend/internal/service/antigravity_oauth_service_test.go b/backend/internal/service/antigravity_oauth_service_test.go
|
||||
new file mode 100644
|
||||
index 00000000..e041c2b4
|
||||
--- /dev/null
|
||||
+++ b/backend/internal/service/antigravity_oauth_service_test.go
|
||||
@@ -0,0 +1,64 @@
|
||||
+package service
|
||||
+
|
||||
+import (
|
||||
+ "testing"
|
||||
+)
|
||||
+
|
||||
+func TestResolveDefaultTierID(t *testing.T) {
|
||||
+ t.Parallel()
|
||||
+
|
||||
+ tests := []struct {
|
||||
+ name string
|
||||
+ loadRaw map[string]any
|
||||
+ want string
|
||||
+ }{
|
||||
+ {
|
||||
+ name: "missing allowedTiers",
|
||||
+ loadRaw: map[string]any{
|
||||
+ "paidTier": map[string]any{"id": "g1-pro-tier"},
|
||||
+ },
|
||||
+ want: "",
|
||||
+ },
|
||||
+ {
|
||||
+ name: "allowedTiers but no default",
|
||||
+ loadRaw: map[string]any{
|
||||
+ "allowedTiers": []any{
|
||||
+ map[string]any{"id": "free-tier", "isDefault": false},
|
||||
+ map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
+ },
|
||||
+ },
|
||||
+ want: "",
|
||||
+ },
|
||||
+ {
|
||||
+ name: "default tier found",
|
||||
+ loadRaw: map[string]any{
|
||||
+ "allowedTiers": []any{
|
||||
+ map[string]any{"id": "free-tier", "isDefault": true},
|
||||
+ map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
+ },
|
||||
+ },
|
||||
+ want: "free-tier",
|
||||
+ },
|
||||
+ {
|
||||
+ name: "default tier id with spaces",
|
||||
+ loadRaw: map[string]any{
|
||||
+ "allowedTiers": []any{
|
||||
+ map[string]any{"id": " standard-tier ", "isDefault": true},
|
||||
+ },
|
||||
+ },
|
||||
+ want: "standard-tier",
|
||||
+ },
|
||||
+ }
|
||||
+
|
||||
+ for _, tc := range tests {
|
||||
+ tc := tc
|
||||
+ t.Run(tc.name, func(t *testing.T) {
|
||||
+ t.Parallel()
|
||||
+
|
||||
+ got := resolveDefaultTierID(tc.loadRaw)
|
||||
+ if got != tc.want {
|
||||
+ t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
|
||||
+ }
|
||||
+ })
|
||||
+ }
|
||||
+}
|
||||
@@ -115,6 +115,23 @@ type LoadCodeAssistResponse struct {
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
// OnboardUserRequest onboardUser 请求
|
||||
type OnboardUserRequest struct {
|
||||
TierID string `json:"tierId"`
|
||||
Metadata struct {
|
||||
IDEType string `json:"ideType"`
|
||||
Platform string `json:"platform,omitempty"`
|
||||
PluginType string `json:"pluginType,omitempty"`
|
||||
} `json:"metadata"`
|
||||
}
|
||||
|
||||
// OnboardUserResponse onboardUser 响应
|
||||
type OnboardUserResponse struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Done bool `json:"done,omitempty"`
|
||||
Response map[string]any `json:"response,omitempty"`
|
||||
}
|
||||
|
||||
// GetTier 获取账户类型
|
||||
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
|
||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
@@ -361,6 +378,113 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
// OnboardUser 触发账号 onboarding,并返回 project_id
|
||||
// 说明:
|
||||
// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject;
|
||||
// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。
|
||||
func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||
tierID = strings.TrimSpace(tierID)
|
||||
if tierID == "" {
|
||||
return "", fmt.Errorf("tier_id 为空")
|
||||
}
|
||||
|
||||
reqBody := OnboardUserRequest{TierID: tierID}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED"
|
||||
reqBody.Metadata.PluginType = "GEMINI"
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
availableURLs := BaseURLs
|
||||
var lastErr error
|
||||
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:onboardUser"
|
||||
|
||||
for attempt := 1; attempt <= 5; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
break
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("onboardUser 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
break
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
break
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
var onboardResp OnboardUserResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil {
|
||||
lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err)
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
if onboardResp.Done {
|
||||
if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" {
|
||||
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||
return projectID, nil
|
||||
}
|
||||
lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id")
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
// done=false 时等待后重试(与 CLIProxyAPI 行为一致)
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return "", lastErr
|
||||
}
|
||||
return "", fmt.Errorf("onboardUser 未返回 project_id")
|
||||
}
|
||||
|
||||
func extractProjectIDFromOnboardResponse(resp map[string]any) string {
|
||||
if len(resp) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if v, ok := resp["cloudaicompanionProject"]; ok {
|
||||
switch project := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(project)
|
||||
case map[string]any:
|
||||
if id, ok := project["id"].(string); ok {
|
||||
return strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ModelQuotaInfo 模型配额信息
|
||||
type ModelQuotaInfo struct {
|
||||
RemainingFraction float64 `json:"remainingFraction"`
|
||||
|
||||
@@ -273,12 +273,20 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
|
||||
return projectID, nil
|
||||
} else if onboardErr != nil {
|
||||
lastErr = onboardErr
|
||||
}
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
@@ -292,6 +300,53 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
|
||||
tierID := resolveDefaultTierID(loadRaw)
|
||||
if tierID == "" {
|
||||
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
|
||||
}
|
||||
|
||||
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
func resolveDefaultTierID(loadRaw map[string]any) string {
|
||||
if len(loadRaw) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
rawTiers, ok := loadRaw["allowedTiers"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
tiers, ok := rawTiers.([]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, rawTier := range tiers {
|
||||
tier, ok := rawTier.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
|
||||
continue
|
||||
}
|
||||
if id, ok := tier["id"].(string); ok {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
|
||||
64
backend/internal/service/antigravity_oauth_service_test.go
Normal file
64
backend/internal/service/antigravity_oauth_service_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveDefaultTierID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
loadRaw map[string]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "missing allowedTiers",
|
||||
loadRaw: map[string]any{
|
||||
"paidTier": map[string]any{"id": "g1-pro-tier"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "allowedTiers but no default",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": false},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "default tier found",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": true},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "free-tier",
|
||||
},
|
||||
{
|
||||
name: "default tier id with spaces",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": " standard-tier ", "isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "standard-tier",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := resolveDefaultTierID(tc.loadRaw)
|
||||
if got != tc.want {
|
||||
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -113,6 +113,26 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 如果账号还没有 project_id,优先尝试在线补齐,避免请求 daily/sandbox 时出现
|
||||
// "Invalid project resource name projects/"。
|
||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||
if tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account); err == nil {
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to persist project_id补齐: %v", updateErr)
|
||||
}
|
||||
if refreshed := strings.TrimSpace(account.GetCredential("access_token")); refreshed != "" {
|
||||
accessToken = refreshed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
|
||||
Reference in New Issue
Block a user