2025-12-25 06:44:18 -08:00
package service
import (
"context"
2025-12-25 21:24:44 -08:00
"encoding/json"
2025-12-25 06:44:18 -08:00
"errors"
"fmt"
2025-12-25 21:24:44 -08:00
"io"
"net/http"
"net/url"
2025-12-25 06:44:18 -08:00
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
)
// GeminiOAuthService drives the Gemini OAuth flow: it issues authorization
// URLs, redeems authorization codes, refreshes tokens and, for Code Assist
// accounts, discovers the Google Cloud project ID to use.
type GeminiOAuthService struct {
	// sessionStore keeps pending OAuth sessions (keyed by session ID) between
	// GenerateAuthURL and ExchangeCode.
	sessionStore *geminicli.SessionStore
	// proxyRepo resolves proxy IDs to proxy URLs for outbound requests.
	proxyRepo ports.ProxyRepository
	// oauthClient performs the token endpoint calls (code exchange, refresh).
	oauthClient ports.GeminiOAuthClient
	// codeAssist is used for project discovery; it may be nil, in which case
	// project auto-detection reports an error.
	codeAssist ports.GeminiCliCodeAssistClient
	// cfg supplies the OAuth client ID, secret and scopes.
	cfg *config.Config
}

// NewGeminiOAuthService wires a GeminiOAuthService with a fresh session
// store and the given collaborators. Call Stop when the service is no longer
// needed so the session store can shut down.
func NewGeminiOAuthService(
	proxyRepo ports.ProxyRepository,
	oauthClient ports.GeminiOAuthClient,
	codeAssist ports.GeminiCliCodeAssistClient,
	cfg *config.Config,
) *GeminiOAuthService {
	svc := new(GeminiOAuthService)
	svc.sessionStore = geminicli.NewSessionStore()
	svc.proxyRepo = proxyRepo
	svc.oauthClient = oauthClient
	svc.codeAssist = codeAssist
	svc.cfg = cfg
	return svc
}
// GeminiAuthURLResult is returned by GenerateAuthURL.
type GeminiAuthURLResult struct {
	// AuthURL is the URL the user must open to grant consent.
	AuthURL string `json:"auth_url"`
	// SessionID identifies the pending session; pass it back to ExchangeCode.
	SessionID string `json:"session_id"`
	// State is the OAuth state value; ExchangeCode requires it to match.
	State string `json:"state"`
}
2025-12-25 21:24:44 -08:00
func ( s * GeminiOAuthService ) GenerateAuthURL ( ctx context . Context , proxyID * int64 , redirectURI , projectID , oauthType string ) ( * GeminiAuthURLResult , error ) {
2025-12-25 06:44:18 -08:00
state , err := geminicli . GenerateState ( )
if err != nil {
return nil , fmt . Errorf ( "failed to generate state: %w" , err )
}
codeVerifier , err := geminicli . GenerateCodeVerifier ( )
if err != nil {
return nil , fmt . Errorf ( "failed to generate code verifier: %w" , err )
}
codeChallenge := geminicli . GenerateCodeChallenge ( codeVerifier )
sessionID , err := geminicli . GenerateSessionID ( )
if err != nil {
return nil , fmt . Errorf ( "failed to generate session ID: %w" , err )
}
var proxyURL string
if proxyID != nil {
proxy , err := s . proxyRepo . GetByID ( ctx , * proxyID )
if err == nil && proxy != nil {
proxyURL = proxy . URL ( )
}
}
2025-12-25 21:24:44 -08:00
// 两种 OAuth 模式都使用相同的配置,只是 scopes 不同
// scopes 会在 EffectiveOAuthConfig 中根据 oauthType 自动选择
oauthCfg := geminicli . OAuthConfig {
ClientID : s . cfg . Gemini . OAuth . ClientID ,
ClientSecret : s . cfg . Gemini . OAuth . ClientSecret ,
Scopes : s . cfg . Gemini . OAuth . Scopes ,
}
2025-12-25 06:44:18 -08:00
session := & geminicli . OAuthSession {
State : state ,
CodeVerifier : codeVerifier ,
ProxyURL : proxyURL ,
RedirectURI : redirectURI ,
2025-12-25 21:24:44 -08:00
ProjectID : strings . TrimSpace ( projectID ) ,
OAuthType : oauthType ,
2025-12-25 06:44:18 -08:00
CreatedAt : time . Now ( ) ,
}
s . sessionStore . Set ( sessionID , session )
2025-12-25 21:24:44 -08:00
effectiveCfg , err := geminicli . EffectiveOAuthConfig ( oauthCfg , oauthType )
if err != nil {
return nil , err
2025-12-25 06:44:18 -08:00
}
2025-12-25 21:24:44 -08:00
// For Code Assist with Gemini CLI credentials, use the CLI's redirect URI
if oauthType == "code_assist" {
redirectURI = geminicli . GeminiCLIRedirectURI
session . RedirectURI = redirectURI
s . sessionStore . Set ( sessionID , session )
}
authURL , err := geminicli . BuildAuthorizationURL ( effectiveCfg , state , codeChallenge , redirectURI , session . ProjectID , oauthType )
2025-12-25 06:44:18 -08:00
if err != nil {
return nil , err
}
return & GeminiAuthURLResult {
AuthURL : authURL ,
SessionID : sessionID ,
State : state ,
} , nil
}
// GeminiExchangeCodeInput carries the parameters for ExchangeCode.
type GeminiExchangeCodeInput struct {
	// SessionID is the value returned by GenerateAuthURL.
	SessionID string
	// State must equal the state stored in the session.
	State string
	// Code is the authorization code returned by Google.
	Code string
	// ProxyID optionally overrides the proxy recorded in the session.
	ProxyID *int64
	// OAuthType is "code_assist" or "ai_studio".
	OAuthType string
}
// GeminiTokenInfo is the token material produced by ExchangeCode and the
// refresh methods, in the shape persisted by BuildAccountCredentials.
type GeminiTokenInfo struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is the token lifetime reported by the token endpoint.
	ExpiresIn int64 `json:"expires_in"`
	// ExpiresAt is a Unix timestamp (seconds) computed by this service with a
	// safety margin subtracted.
	ExpiresAt int64  `json:"expires_at"`
	TokenType string `json:"token_type"`
	Scope     string `json:"scope,omitempty"`
	ProjectID string `json:"project_id,omitempty"`
	// OAuthType is "code_assist" or "ai_studio".
	OAuthType string `json:"oauth_type,omitempty"`
}
// ExchangeCode redeems an authorization code for tokens, completing the flow
// started by GenerateAuthURL.
//
// The session must exist and input.State must be non-blank and equal the
// session's state. input.ProxyID, when set and resolvable, overrides the
// session's proxy. Sessions without an oauth type are treated as
// "code_assist". On a successful code exchange the session is deleted.
//
// For "code_assist" a project ID is mandatory. If none was supplied when the
// auth URL was generated, it is auto-detected. When detection yields nothing,
// an error is returned that includes the detection failure, if any. For
// "ai_studio" the project ID is optional.
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
	if input == nil {
		return nil, errors.New("exchange code input is required")
	}
	session, ok := s.sessionStore.Get(input.SessionID)
	if !ok {
		return nil, errors.New("session not found or expired")
	}
	if strings.TrimSpace(input.State) == "" || input.State != session.State {
		return nil, errors.New("invalid state")
	}

	proxyURL := session.ProxyURL
	if input.ProxyID != nil {
		proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
		if err == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}

	tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, session.RedirectURI, proxyURL)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange code: %w", err)
	}

	projectID := strings.TrimSpace(session.ProjectID)
	oauthType := session.OAuthType
	if oauthType == "" {
		// Default to code_assist for sessions created without an oauth type.
		oauthType = "code_assist"
	}

	s.sessionStore.Delete(input.SessionID)

	// Subtract a 5-minute safety window to absorb network latency and clock skew.
	expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300

	// Code Assist requires a project ID; AI Studio works without one.
	if oauthType == "code_assist" && projectID == "" {
		detected, fetchErr := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
		projectID = strings.TrimSpace(detected)
		if projectID == "" {
			const msg = "missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project"
			if fetchErr != nil {
				// Surface the detection failure instead of discarding it.
				return nil, fmt.Errorf("%s (auto-detect failed: %w)", msg, fetchErr)
			}
			return nil, errors.New(msg)
		}
	}

	return &GeminiTokenInfo{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		TokenType:    tokenResp.TokenType,
		ExpiresIn:    tokenResp.ExpiresIn,
		ExpiresAt:    expiresAt,
		Scope:        tokenResp.Scope,
		ProjectID:    projectID,
		OAuthType:    oauthType,
	}, nil
}
// RefreshToken exchanges refreshToken for a new access token, retrying
// transient failures.
//
// Retry behavior:
//   - It makes up to four attempts: the initial call plus three retries, with
//     1s, 2s and 4s backoff (capped at 30s).
//   - Errors matched by isNonRetryableGeminiOAuthError (e.g. invalid_grant)
//     are returned immediately.
//   - Backoff waits stop early when ctx is done, returning the context error.
//
// Token endpoints commonly omit refresh_token from refresh responses. In that
// case the supplied refreshToken is carried over, so callers that persist the
// result do not lose it.
//
// ExpiresAt is a Unix timestamp (seconds) moved 5 minutes earlier than the
// reported lifetime to absorb network latency and clock skew.
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
	const (
		maxRetries = 3
		maxBackoff = 30 * time.Second
	)
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			backoff := time.Duration(1<<uint(attempt-1)) * time.Second
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, fmt.Errorf("token refresh canceled after %d attempt(s) (last error: %v): %w", attempt, lastErr, ctx.Err())
			case <-timer.C:
			}
		}

		tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
		if err != nil {
			if isNonRetryableGeminiOAuthError(err) {
				return nil, err
			}
			lastErr = err
			continue
		}

		newRefreshToken := tokenResp.RefreshToken
		if newRefreshToken == "" {
			// No rotation happened; keep using the existing refresh token.
			newRefreshToken = refreshToken
		}
		expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
		return &GeminiTokenInfo{
			AccessToken:  tokenResp.AccessToken,
			RefreshToken: newRefreshToken,
			TokenType:    tokenResp.TokenType,
			ExpiresIn:    tokenResp.ExpiresIn,
			ExpiresAt:    expiresAt,
			Scope:        tokenResp.Scope,
		}, nil
	}
	return nil, fmt.Errorf("token refresh failed after retries: %w", lastErr)
}
// isNonRetryableGeminiOAuthError reports whether err carries an OAuth error
// code that retrying cannot fix: invalid_grant, invalid_client,
// unauthorized_client or access_denied.
//
// Matching is a substring search on err.Error(), so wrapped errors are
// recognized too. A nil error is never classified as non-retryable.
func isNonRetryableGeminiOAuthError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	for _, code := range [...]string{
		"invalid_grant",
		"invalid_client",
		"unauthorized_client",
		"access_denied",
	} {
		if strings.Contains(msg, code) {
			return true
		}
	}
	return false
}
// RefreshAccountToken refreshes the OAuth tokens stored on a Gemini OAuth
// account.
//
// It uses the account's refresh_token credential and routes the refresh
// through the account's proxy when one is configured and resolvable. The
// returned token info keeps the account's oauth_type (defaulting to
// "code_assist" for backward compatibility) and its existing project_id.
//
// For Code Assist accounts without a project_id, the project is auto-detected,
// and an error is returned if detection fails or yields nothing. AI Studio
// accounts never fail because of a missing project_id.
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) {
	if account == nil || account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth {
		return nil, errors.New("account is not a Gemini OAuth account")
	}

	refreshToken := account.GetCredential("refresh_token")
	if strings.TrimSpace(refreshToken) == "" {
		return nil, errors.New("no refresh token available")
	}

	var proxyURL string
	if account.ProxyID != nil {
		proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
		if err == nil && proxy != nil {
			proxyURL = proxy.URL()
		}
	}

	tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
	if err != nil {
		return nil, err
	}

	// Preserve oauth_type from the account (defaults to code_assist for backward compatibility).
	oauthType := strings.TrimSpace(account.GetCredential("oauth_type"))
	if oauthType == "" {
		oauthType = "code_assist"
	}
	tokenInfo.OAuthType = oauthType

	// Preserve the account's project_id when present.
	if existing := strings.TrimSpace(account.GetCredential("project_id")); existing != "" {
		tokenInfo.ProjectID = existing
	}

	// Code Assist requires a project_id, so auto-detect it if missing.
	// AI Studio OAuth treats project_id as optional and must not block refresh.
	if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
		projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
		if err != nil {
			return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
		}
		projectID = strings.TrimSpace(projectID)
		if projectID == "" {
			return nil, errors.New("failed to auto-detect project_id: empty result")
		}
		tokenInfo.ProjectID = projectID
	}
	return tokenInfo, nil
}
func ( s * GeminiOAuthService ) BuildAccountCredentials ( tokenInfo * GeminiTokenInfo ) map [ string ] any {
creds := map [ string ] any {
"access_token" : tokenInfo . AccessToken ,
"expires_at" : strconv . FormatInt ( tokenInfo . ExpiresAt , 10 ) ,
}
if tokenInfo . RefreshToken != "" {
creds [ "refresh_token" ] = tokenInfo . RefreshToken
}
if tokenInfo . TokenType != "" {
creds [ "token_type" ] = tokenInfo . TokenType
}
if tokenInfo . Scope != "" {
creds [ "scope" ] = tokenInfo . Scope
}
if tokenInfo . ProjectID != "" {
creds [ "project_id" ] = tokenInfo . ProjectID
}
2025-12-25 21:24:44 -08:00
if tokenInfo . OAuthType != "" {
creds [ "oauth_type" ] = tokenInfo . OAuthType
}
2025-12-25 06:44:18 -08:00
return creds
}
// Stop shuts down the service's OAuth session store.
func (s *GeminiOAuthService) Stop() {
	store := s.sessionStore
	store.Stop()
}
// fetchProjectID discovers the Google Cloud project to use for Code Assist.
//
// Resolution order:
//  1. loadCodeAssist's cloudaicompanionProject, when non-empty;
//  2. onboardUser with a tier chosen from loadCodeAssist's allowedTiers,
//     polled up to 5 times, 2s apart, until it reports done;
//  3. the Cloud Resource Manager project list, consulted whenever onboarding
//     errors, completes without a usable project ID, or never completes.
//
// Polling waits stop early when ctx is done.
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
	if s.codeAssist == nil {
		return "", errors.New("code assist client not configured")
	}

	loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
	if loadErr == nil && loadResp != nil {
		if id := strings.TrimSpace(loadResp.CloudAICompanionProject); id != "" {
			return id, nil
		}
	}

	// Pick the tier marked as default. If none is marked, pick the first
	// non-empty tier ID. With no usable tiers, fall back to "LEGACY".
	tierID := "LEGACY"
	if loadResp != nil {
		defaultFound := false
		for _, tier := range loadResp.AllowedTiers {
			if id := strings.TrimSpace(tier.ID); tier.IsDefault && id != "" {
				tierID, defaultFound = id, true
				break
			}
		}
		if !defaultFound {
			for _, tier := range loadResp.AllowedTiers {
				if id := strings.TrimSpace(tier.ID); id != "" {
					tierID = id
					break
				}
			}
		}
	}

	req := &geminicli.OnboardUserRequest{
		TierID: tierID,
		Metadata: geminicli.LoadCodeAssistMetadata{
			IDEType:    "ANTIGRAVITY",
			Platform:   "PLATFORM_UNSPECIFIED",
			PluginType: "GEMINI",
		},
	}

	const (
		maxAttempts  = 5
		pollInterval = 2 * time.Second
	)
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req)
		if err != nil {
			// If onboarding fails (e.g. INVALID_ARGUMENT), try Cloud Resource Manager projects.
			if id, ok := projectIDFromResourceManagerFallback(ctx, accessToken, proxyURL); ok {
				return id, nil
			}
			return "", err
		}
		if resp.Done {
			if resp.Response != nil {
				if id := companionProjectIDFromOnboard(resp.Response.CloudAICompanionProject); id != "" {
					return id, nil
				}
			}
			// Completed, but without a usable project ID.
			if id, ok := projectIDFromResourceManagerFallback(ctx, accessToken, proxyURL); ok {
				return id, nil
			}
			return "", errors.New("onboardUser completed but no project_id returned")
		}
		if attempt == maxAttempts {
			break // no point waiting after the final attempt
		}
		timer := time.NewTimer(pollInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return "", fmt.Errorf("onboardUser polling canceled: %w", ctx.Err())
		case <-timer.C:
		}
	}

	if id, ok := projectIDFromResourceManagerFallback(ctx, accessToken, proxyURL); ok {
		return id, nil
	}
	if loadErr != nil {
		return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
	}
	return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}

// companionProjectIDFromOnboard extracts a trimmed project ID from an
// onboardUser cloudaicompanionProject value. The value is either a bare
// string or an object with an "id" string field. It returns "" when no usable
// ID is present.
func companionProjectIDFromOnboard(v any) string {
	switch p := v.(type) {
	case string:
		return strings.TrimSpace(p)
	case map[string]any:
		if id, ok := p["id"].(string); ok {
			return strings.TrimSpace(id)
		}
	}
	return ""
}

// projectIDFromResourceManagerFallback queries Cloud Resource Manager and
// reports the trimmed project ID and true only when a non-empty ID was found.
func projectIDFromResourceManagerFallback(ctx context.Context, accessToken, proxyURL string) (string, bool) {
	id, err := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
	if err != nil {
		return "", false
	}
	id = strings.TrimSpace(id)
	return id, id != ""
}
2025-12-25 21:24:44 -08:00
type googleCloudProject struct {
ProjectID string ` json:"projectId" `
DisplayName string ` json:"name" `
LifecycleState string ` json:"lifecycleState" `
}
type googleCloudProjectsResponse struct {
Projects [ ] googleCloudProject ` json:"projects" `
}
func fetchProjectIDFromResourceManager ( ctx context . Context , accessToken , proxyURL string ) ( string , error ) {
req , err := http . NewRequestWithContext ( ctx , http . MethodGet , "https://cloudresourcemanager.googleapis.com/v1/projects" , nil )
if err != nil {
return "" , fmt . Errorf ( "failed to create resource manager request: %w" , err )
}
req . Header . Set ( "Authorization" , "Bearer " + accessToken )
req . Header . Set ( "User-Agent" , geminicli . GeminiCLIUserAgent )
client := & http . Client { Timeout : 30 * time . Second }
if strings . TrimSpace ( proxyURL ) != "" {
if proxyURLParsed , err := url . Parse ( strings . TrimSpace ( proxyURL ) ) ; err == nil {
client . Transport = & http . Transport { Proxy : http . ProxyURL ( proxyURLParsed ) }
}
}
resp , err := client . Do ( req )
if err != nil {
return "" , fmt . Errorf ( "resource manager request failed: %w" , err )
}
2025-12-25 21:35:30 -08:00
defer func ( ) { _ = resp . Body . Close ( ) } ( )
2025-12-25 21:24:44 -08:00
bodyBytes , err := io . ReadAll ( resp . Body )
if err != nil {
return "" , fmt . Errorf ( "failed to read resource manager response: %w" , err )
}
if resp . StatusCode != http . StatusOK {
return "" , fmt . Errorf ( "resource manager HTTP %d: %s" , resp . StatusCode , string ( bodyBytes ) )
}
var projectsResp googleCloudProjectsResponse
if err := json . Unmarshal ( bodyBytes , & projectsResp ) ; err != nil {
return "" , fmt . Errorf ( "failed to parse resource manager response: %w" , err )
}
active := make ( [ ] googleCloudProject , 0 , len ( projectsResp . Projects ) )
for _ , p := range projectsResp . Projects {
if p . LifecycleState == "ACTIVE" && strings . TrimSpace ( p . ProjectID ) != "" {
active = append ( active , p )
}
}
if len ( active ) == 0 {
return "" , errors . New ( "no ACTIVE projects found from resource manager" )
}
// Prefer likely companion projects first.
for _ , p := range active {
id := strings . ToLower ( strings . TrimSpace ( p . ProjectID ) )
name := strings . ToLower ( strings . TrimSpace ( p . DisplayName ) )
if strings . Contains ( id , "cloud-ai-companion" ) || strings . Contains ( name , "cloud ai companion" ) || strings . Contains ( name , "code assist" ) {
return strings . TrimSpace ( p . ProjectID ) , nil
}
}
// Then prefer "default".
for _ , p := range active {
id := strings . ToLower ( strings . TrimSpace ( p . ProjectID ) )
name := strings . ToLower ( strings . TrimSpace ( p . DisplayName ) )
if strings . Contains ( id , "default" ) || strings . Contains ( name , "default" ) {
return strings . TrimSpace ( p . ProjectID ) , nil
}
}
return strings . TrimSpace ( active [ 0 ] . ProjectID ) , nil
}