Files
sub2api/backend/internal/service/gemini_oauth_service.go

557 lines
18 KiB
Go
Raw Normal View History

package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
type GeminiOAuthService struct {
sessionStore *geminicli.SessionStore
proxyRepo ProxyRepository
oauthClient GeminiOAuthClient
codeAssist GeminiCliCodeAssistClient
cfg *config.Config
}
type GeminiOAuthCapabilities struct {
AIStudioOAuthEnabled bool `json:"ai_studio_oauth_enabled"`
RequiredRedirectURIs []string `json:"required_redirect_uris"`
}
func NewGeminiOAuthService(
proxyRepo ProxyRepository,
oauthClient GeminiOAuthClient,
codeAssist GeminiCliCodeAssistClient,
cfg *config.Config,
) *GeminiOAuthService {
return &GeminiOAuthService{
sessionStore: geminicli.NewSessionStore(),
proxyRepo: proxyRepo,
oauthClient: oauthClient,
codeAssist: codeAssist,
cfg: cfg,
}
}
func (s *GeminiOAuthService) GetOAuthConfig() *GeminiOAuthCapabilities {
// AI Studio OAuth is only enabled when the operator configures a custom OAuth client.
clientID := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientID)
clientSecret := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientSecret)
enabled := clientID != "" && clientSecret != "" &&
(clientID != geminicli.GeminiCLIOAuthClientID || clientSecret != geminicli.GeminiCLIOAuthClientSecret)
return &GeminiOAuthCapabilities{
AIStudioOAuthEnabled: enabled,
RequiredRedirectURIs: []string{geminicli.AIStudioOAuthRedirectURI},
}
}
type GeminiAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
State string `json:"state"`
}
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
state, err := geminicli.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
codeVerifier, err := geminicli.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
}
codeChallenge := geminicli.GenerateCodeChallenge(codeVerifier)
sessionID, err := geminicli.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err)
}
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - ai_studio: requires a user-provided OAuth client.
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
oauthCfg.ClientID = ""
oauthCfg.ClientSecret = ""
}
session := &geminicli.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
RedirectURI: redirectURI,
ProjectID: strings.TrimSpace(projectID),
OAuthType: oauthType,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
effectiveCfg, err := geminicli.EffectiveOAuthConfig(oauthCfg, oauthType)
if err != nil {
return nil, err
}
isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID &&
effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret
// AI Studio OAuth requires a user-provided OAuth client (built-in Gemini CLI client is scope-restricted).
if oauthType == "ai_studio" && isBuiltinClient {
return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client (GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET). If you don't want to configure an OAuth client, please use an AI Studio API Key account instead")
}
// Redirect URI strategy:
// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode)
// - ai_studio: use localhost callback for manual copy/paste flow
if oauthType == "code_assist" {
redirectURI = geminicli.GeminiCLIRedirectURI
} else {
redirectURI = geminicli.AIStudioOAuthRedirectURI
}
session.RedirectURI = redirectURI
s.sessionStore.Set(sessionID, session)
authURL, err := geminicli.BuildAuthorizationURL(effectiveCfg, state, codeChallenge, redirectURI, session.ProjectID, oauthType)
if err != nil {
return nil, err
}
return &GeminiAuthURLResult{
AuthURL: authURL,
SessionID: sessionID,
State: state,
}, nil
}
type GeminiExchangeCodeInput struct {
SessionID string
State string
Code string
ProxyID *int64
OAuthType string // "code_assist" 或 "ai_studio"
}
type GeminiTokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
return nil, fmt.Errorf("invalid state")
}
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
redirectURI := session.RedirectURI
// Resolve oauth_type early (defaults to code_assist for backward compatibility).
oauthType := session.OAuthType
if oauthType == "" {
oauthType = "code_assist"
}
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
if oauthType == "ai_studio" {
effectiveCfg, err := geminicli.EffectiveOAuthConfig(geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}, "ai_studio")
if err != nil {
return nil, err
}
isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID &&
effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret
if isBuiltinClient {
return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client. Please use an AI Studio API Key account, or configure GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET and re-authorize")
}
}
// code_assist always uses the built-in client and its fixed redirect URI.
if oauthType == "code_assist" {
redirectURI = geminicli.GeminiCLIRedirectURI
}
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
projectID := sessionProjectID
// 对于 code_assist 模式project_id 是必需的
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
if oauthType == "code_assist" {
if projectID == "" {
var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
}
}
if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
}
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
OAuthType: oauthType,
}, nil
}
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
var lastErr error
for attempt := 0; attempt <= 3; attempt++ {
if attempt > 0 {
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
time.Sleep(backoff)
}
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
if err == nil {
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
}, nil
}
if isNonRetryableGeminiOAuthError(err) {
return nil, err
}
lastErr = err
}
return nil, fmt.Errorf("token refresh failed after retries: %w", lastErr)
}
func isNonRetryableGeminiOAuthError(err error) bool {
msg := err.Error()
nonRetryable := []string{
"invalid_grant",
"invalid_client",
"unauthorized_client",
"access_denied",
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {
return true
}
}
return false
}
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*GeminiTokenInfo, error) {
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("account is not a Gemini OAuth account")
}
refreshToken := account.GetCredential("refresh_token")
if strings.TrimSpace(refreshToken) == "" {
return nil, fmt.Errorf("no refresh token available")
}
// Preserve oauth_type from the account (defaults to code_assist for backward compatibility).
oauthType := strings.TrimSpace(account.GetCredential("oauth_type"))
if oauthType == "" {
oauthType = "code_assist"
}
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
tokenInfo, err := s.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
// Backward compatibility:
// Older versions could refresh Code Assist tokens using a user-provided OAuth client when configured.
// If the refresh token was originally issued to that custom client, forcing the built-in client will
// fail with "unauthorized_client". In that case, retry with the custom client (ai_studio path) when available.
if err != nil && oauthType == "code_assist" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
if alt, altErr := s.RefreshToken(ctx, "ai_studio", refreshToken, proxyURL); altErr == nil {
tokenInfo = alt
err = nil
}
}
if err != nil {
// Provide a more actionable error for common OAuth client mismatch issues.
if strings.Contains(err.Error(), "unauthorized_client") {
return nil, fmt.Errorf("%w (OAuth client mismatch: the refresh_token is bound to the OAuth client used during authorization; please re-authorize this account or restore the original GEMINI_OAUTH_CLIENT_ID/SECRET)", err)
}
return nil, err
}
tokenInfo.OAuthType = oauthType
// Preserve account's project_id when present.
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
if existingProjectID != "" {
tokenInfo.ProjectID = existingProjectID
}
// For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any {
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
}
if tokenInfo.RefreshToken != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.TokenType != "" {
creds["token_type"] = tokenInfo.TokenType
}
if tokenInfo.Scope != "" {
creds["scope"] = tokenInfo.Scope
}
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType
}
return creds
}
func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop()
}
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
if s.codeAssist == nil {
return "", errors.New("code assist client not configured")
}
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
tierID := "LEGACY"
if loadResp != nil {
for _, tier := range loadResp.AllowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
}
req := &geminicli.OnboardUserRequest{
TierID: tierID,
Metadata: geminicli.LoadCodeAssistMetadata{
IDEType: "ANTIGRAVITY",
Platform: "PLATFORM_UNSPECIFIED",
PluginType: "GEMINI",
},
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req)
if err != nil {
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
return "", err
}
if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) {
case string:
return strings.TrimSpace(v), nil
case map[string]any:
if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil
}
}
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
return "", errors.New("onboardUser completed but no project_id returned")
}
time.Sleep(2 * time.Second)
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
}
if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
}
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
type googleCloudProject struct {
ProjectID string `json:"projectId"`
DisplayName string `json:"name"`
LifecycleState string `json:"lifecycleState"`
}
type googleCloudProjectsResponse struct {
Projects []googleCloudProject `json:"projects"`
}
func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
if err != nil {
return "", fmt.Errorf("failed to create resource manager request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("resource manager request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read resource manager response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("resource manager HTTP %d: %s", resp.StatusCode, string(bodyBytes))
}
var projectsResp googleCloudProjectsResponse
if err := json.Unmarshal(bodyBytes, &projectsResp); err != nil {
return "", fmt.Errorf("failed to parse resource manager response: %w", err)
}
active := make([]googleCloudProject, 0, len(projectsResp.Projects))
for _, p := range projectsResp.Projects {
if p.LifecycleState == "ACTIVE" && strings.TrimSpace(p.ProjectID) != "" {
active = append(active, p)
}
}
if len(active) == 0 {
return "", errors.New("no ACTIVE projects found from resource manager")
}
// Prefer likely companion projects first.
for _, p := range active {
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
if strings.Contains(id, "cloud-ai-companion") || strings.Contains(name, "cloud ai companion") || strings.Contains(name, "code assist") {
return strings.TrimSpace(p.ProjectID), nil
}
}
// Then prefer "default".
for _, p := range active {
id := strings.ToLower(strings.TrimSpace(p.ProjectID))
name := strings.ToLower(strings.TrimSpace(p.DisplayName))
if strings.Contains(id, "default") || strings.Contains(name, "default") {
return strings.TrimSpace(p.ProjectID), nil
}
}
return strings.TrimSpace(active[0].ProjectID), nil
}