mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-09 17:44:46 +08:00
refactor: optimize project_id fill to lightweight approach
- Replace heavy RefreshAccountToken with lightweight tryFillProjectID (loadCodeAssist → onboardUser → fallback), consistent with Antigravity-Manager's behavior - Add sync.Map cooldown/dedup (60s) to prevent repeated fill attempts - Add fallback project_id "bamboo-precept-lgxtn" matching AM - Extract mergeCredentials helper to eliminate duplication - Use slog structured logging instead of log.Printf - Fix time.Sleep in OnboardUser to context-aware select - Fix strings.NewReader(string(bodyBytes)) → bytes.NewReader(bodyBytes) - Remove redundant tc := tc in test (Go 1.22+) - Add nil guard in persistProjectID for test safety
This commit is contained in:
@@ -326,7 +326,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
for urlIdx, baseURL := range availableURLs {
|
for urlIdx, baseURL := range availableURLs {
|
||||||
apiURL := baseURL + "/v1internal:loadCodeAssist"
|
apiURL := baseURL + "/v1internal:loadCodeAssist"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||||
continue
|
continue
|
||||||
@@ -405,7 +405,7 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s
|
|||||||
apiURL := baseURL + "/v1internal:onboardUser"
|
apiURL := baseURL + "/v1internal:onboardUser"
|
||||||
|
|
||||||
for attempt := 1; attempt <= 5; attempt++ {
|
for attempt := 1; attempt <= 5; attempt++ {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||||
break
|
break
|
||||||
@@ -456,7 +456,11 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// done=false 时等待后重试(与 CLIProxyAPI 行为一致)
|
// done=false 时等待后重试(与 CLIProxyAPI 行为一致)
|
||||||
time.Sleep(2 * time.Second)
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -521,7 +525,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
for urlIdx, baseURL := range availableURLs {
|
for urlIdx, baseURL := range availableURLs {
|
||||||
apiURL := baseURL + "/v1internal:fetchAvailableModels"
|
apiURL := baseURL + "/v1internal:fetchAvailableModels"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ func TestResolveDefaultTierID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
tc := tc
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -3,16 +3,24 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||||
antigravityTokenCacheSkew = 5 * time.Minute
|
antigravityTokenCacheSkew = 5 * time.Minute
|
||||||
|
|
||||||
|
// projectIDFillCooldown 同一账号 project_id 补齐失败后的冷却时间
|
||||||
|
projectIDFillCooldown = 60 * time.Second
|
||||||
|
|
||||||
|
// fallbackProjectID 所有获取方式都失败时的兜底值(与 Antigravity-Manager 一致)
|
||||||
|
fallbackProjectID = "bamboo-precept-lgxtn"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||||
@@ -23,6 +31,9 @@ type AntigravityTokenProvider struct {
|
|||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache AntigravityTokenCache
|
tokenCache AntigravityTokenCache
|
||||||
antigravityOAuthService *AntigravityOAuthService
|
antigravityOAuthService *AntigravityOAuthService
|
||||||
|
|
||||||
|
// projectIDFillAttempts 记录每个账号最近一次 project_id 补齐尝试时间,用于冷却去重
|
||||||
|
projectIDFillAttempts sync.Map // map[int64]time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAntigravityTokenProvider(
|
func NewAntigravityTokenProvider(
|
||||||
@@ -94,14 +105,10 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
for k, v := range account.Credentials {
|
mergeCredentials(newCredentials, account.Credentials)
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
account.Credentials = newCredentials
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
slog.Error("failed to update account credentials after token refresh", "account_id", account.ID, "error", updateErr)
|
||||||
}
|
}
|
||||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||||
}
|
}
|
||||||
@@ -113,27 +120,12 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果账号还没有 project_id,优先尝试在线补齐,避免请求 daily/sandbox 时出现
|
// 3. 如果缺少 project_id,轻量补齐(不刷新 token)
|
||||||
// "Invalid project resource name projects/"。
|
if strings.TrimSpace(account.GetCredential("project_id")) == "" {
|
||||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
p.tryFillProjectID(ctx, account, accessToken)
|
||||||
if tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account); err == nil {
|
|
||||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
|
||||||
log.Printf("[AntigravityTokenProvider] Failed to persist project_id补齐: %v", updateErr)
|
|
||||||
}
|
|
||||||
if refreshed := strings.TrimSpace(account.GetCredential("access_token")); refreshed != "" {
|
|
||||||
accessToken = refreshed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 4. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if isStale && latestAccount != nil {
|
if isStale && latestAccount != nil {
|
||||||
@@ -164,6 +156,70 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tryFillProjectID 轻量级 project_id 补齐(与 Antigravity-Manager 保持一致)
|
||||||
|
// 只调用 loadCodeAssist + onboardUser,不刷新 token。
|
||||||
|
// 带冷却去重:同一账号 60s 内不重复尝试。
|
||||||
|
func (p *AntigravityTokenProvider) tryFillProjectID(ctx context.Context, account *Account, accessToken string) {
|
||||||
|
// 冷却检查:60s 内不重复尝试
|
||||||
|
if lastAttempt, ok := p.projectIDFillAttempts.Load(account.ID); ok {
|
||||||
|
if time.Since(lastAttempt.(time.Time)) < projectIDFillCooldown {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.projectIDFillAttempts.Store(account.ID, time.Now())
|
||||||
|
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
client := antigravity.NewClient(proxyURL)
|
||||||
|
|
||||||
|
// Step 1: loadCodeAssist(单次调用,不重试)
|
||||||
|
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||||
|
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||||
|
p.persistProjectID(ctx, account, loadResp.CloudAICompanionProject)
|
||||||
|
p.projectIDFillAttempts.Delete(account.ID) // 成功后清除冷却
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: onboardUser(loadCodeAssist 成功但未返回 project_id 时)
|
||||||
|
if err == nil {
|
||||||
|
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
|
||||||
|
p.persistProjectID(ctx, account, projectID)
|
||||||
|
p.projectIDFillAttempts.Delete(account.ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: 兜底值(与 Antigravity-Manager 一致)
|
||||||
|
slog.Warn("project_id fill failed, using fallback",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"fallback", fallbackProjectID,
|
||||||
|
)
|
||||||
|
p.persistProjectID(ctx, account, fallbackProjectID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// persistProjectID 将 project_id 写入账号凭证并持久化
|
||||||
|
func (p *AntigravityTokenProvider) persistProjectID(ctx context.Context, account *Account, projectID string) {
|
||||||
|
account.Credentials["project_id"] = projectID
|
||||||
|
if p.accountRepo == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||||
|
slog.Error("failed to persist project_id", "account_id", account.ID, "error", updateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeCredentials 将 old 中不存在于 new 的字段合并到 new
|
||||||
|
func mergeCredentials(newCreds, oldCreds map[string]any) {
|
||||||
|
for k, v := range oldCreds {
|
||||||
|
if _, exists := newCreds[k]; !exists {
|
||||||
|
newCreds[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func AntigravityTokenCacheKey(account *Account) string {
|
func AntigravityTokenCacheKey(account *Account) string {
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
|
|||||||
Reference in New Issue
Block a user