From aea48ae1abae364c928cae6dd388c1bfd9dd61bf Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 06:43:00 -0800 Subject: [PATCH 01/42] =?UTF-8?q?feat(config):=20=E6=96=B0=E5=A2=9E=20Gemi?= =?UTF-8?q?ni=20=E9=85=8D=E7=BD=AE=E9=A1=B9=E5=92=8C=20geminicli=20?= =?UTF-8?q?=E6=A0=B8=E5=BF=83=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 Gemini OAuth 配置结构 - 实现 geminicli 包(OAuth、Token、CodeAssist 类型) - 更新配置示例文件 --- backend/internal/config/config.go | 16 ++ .../pkg/geminicli/codeassist_types.go | 38 ++++ backend/internal/pkg/geminicli/constants.go | 20 +++ backend/internal/pkg/geminicli/oauth.go | 167 ++++++++++++++++++ backend/internal/pkg/geminicli/sanitize.go | 46 +++++ backend/internal/pkg/geminicli/token_types.go | 9 + deploy/config.example.yaml | 11 ++ 7 files changed, 307 insertions(+) create mode 100644 backend/internal/pkg/geminicli/codeassist_types.go create mode 100644 backend/internal/pkg/geminicli/constants.go create mode 100644 backend/internal/pkg/geminicli/oauth.go create mode 100644 backend/internal/pkg/geminicli/sanitize.go create mode 100644 backend/internal/pkg/geminicli/token_types.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 18fb162d..2050a0cf 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -18,6 +18,17 @@ type Config struct { Gateway GatewayConfig `mapstructure:"gateway"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` +} + +type GeminiConfig struct { + OAuth GeminiOAuthConfig `mapstructure:"oauth"` +} + +type GeminiOAuthConfig struct { + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + Scopes string `mapstructure:"scopes"` } // TokenRefreshConfig OAuth token自动刷新配置 @@ -214,6 +225,11 @@ func setDefaults() { viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新 viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + + // Gemini (optional) + viper.SetDefault("gemini.oauth.client_id", "") + viper.SetDefault("gemini.oauth.client_secret", "") + viper.SetDefault("gemini.oauth.scopes", "") } func (c *Config) Validate() error { diff --git a/backend/internal/pkg/geminicli/codeassist_types.go b/backend/internal/pkg/geminicli/codeassist_types.go new file mode 100644 index 00000000..59d3ef78 --- /dev/null +++ b/backend/internal/pkg/geminicli/codeassist_types.go @@ -0,0 +1,38 @@ +package geminicli + +// LoadCodeAssistRequest matches done-hub's internal Code Assist call. +type LoadCodeAssistRequest struct { + Metadata LoadCodeAssistMetadata `json:"metadata"` +} + +type LoadCodeAssistMetadata struct { + IDEType string `json:"ideType"` + Platform string `json:"platform"` + PluginType string `json:"pluginType"` +} + +type LoadCodeAssistResponse struct { + CurrentTier string `json:"currentTier,omitempty"` + CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"` + AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"` +} + +type AllowedTier struct { + ID string `json:"id"` + IsDefault bool `json:"isDefault,omitempty"` +} + +type OnboardUserRequest struct { + TierID string `json:"tierId"` + Metadata LoadCodeAssistMetadata `json:"metadata"` +} + +type OnboardUserResponse struct { + Done bool `json:"done"` + Response *OnboardUserResultData `json:"response,omitempty"` + Name string `json:"name,omitempty"` +} + +type OnboardUserResultData struct { + CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"` +} diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go new file mode 100644 index 00000000..7ad33d75 --- /dev/null +++ b/backend/internal/pkg/geminicli/constants.go @@ -0,0 +1,20 @@ +package geminicli + +import "time" + +const ( + AIStudioBaseURL = "https://generativelanguage.googleapis.com" + GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com" + + AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth" + TokenURL = "https://oauth2.googleapis.com/token" + + // DefaultScopes is the minimal scope set for GeminiCli/CodeAssist usage. + // Keep this conservative and expand only when we have a clear requirement. + DefaultScopes = "https://www.googleapis.com/auth/cloud-platform" + + SessionTTL = 30 * time.Minute + + // GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints. + GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)" +) diff --git a/backend/internal/pkg/geminicli/oauth.go b/backend/internal/pkg/geminicli/oauth.go new file mode 100644 index 00000000..2b6cf714 --- /dev/null +++ b/backend/internal/pkg/geminicli/oauth.go @@ -0,0 +1,167 @@ +package geminicli + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "net/url" + "strings" + "sync" + "time" +) + +type OAuthConfig struct { + ClientID string + ClientSecret string + Scopes string +} + +type OAuthSession struct { + State string `json:"state"` + CodeVerifier string `json:"code_verifier"` + ProxyURL string `json:"proxy_url,omitempty"` + RedirectURI string `json:"redirect_uri"` + CreatedAt time.Time `json:"created_at"` +} + +type SessionStore struct { + mu sync.RWMutex + sessions map[string]*OAuthSession + stopCh chan struct{} +} + +func NewSessionStore() *SessionStore { + store := &SessionStore{ + sessions: make(map[string]*OAuthSession), + stopCh: make(chan struct{}), + } + go store.cleanup() + return store +} + +func (s *SessionStore) Set(sessionID string, session *OAuthSession) { + s.mu.Lock() + defer s.mu.Unlock() + s.sessions[sessionID] = session +} + +func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + session, ok := s.sessions[sessionID] + if !ok { + return nil, false + } + if time.Since(session.CreatedAt) > SessionTTL { + return nil, false + } + return session, true +} + +func (s *SessionStore) Delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, sessionID) +} + +func (s *SessionStore) Stop() { + select { + case <-s.stopCh: + return + default: + close(s.stopCh) + } +} + +func (s *SessionStore) cleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + s.mu.Lock() + for id, session := range s.sessions { + if time.Since(session.CreatedAt) > SessionTTL { + delete(s.sessions, id) + } + } + s.mu.Unlock() + } + } +} + +func GenerateRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return nil, err + } + return b, nil +} + +func GenerateState() (string, error) { + bytes, err := GenerateRandomBytes(32) + if err != nil { + return "", err + } + return base64URLEncode(bytes), nil +} + +func GenerateSessionID() (string, error) { + bytes, err := GenerateRandomBytes(16) + if err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} + +// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars). +func GenerateCodeVerifier() (string, error) { + bytes, err := GenerateRandomBytes(32) + if err != nil { + return "", err + } + return base64URLEncode(bytes), nil +} + +func GenerateCodeChallenge(verifier string) string { + hash := sha256.Sum256([]byte(verifier)) + return base64URLEncode(hash[:]) +} + +func base64URLEncode(data []byte) string { + return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=") +} + +func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI string) (string, error) { + if strings.TrimSpace(cfg.ClientID) == "" { + return "", fmt.Errorf("gemini oauth client_id is empty") + } + redirectURI = strings.TrimSpace(redirectURI) + if redirectURI == "" { + return "", fmt.Errorf("redirect_uri is required") + } + + scopes := strings.TrimSpace(cfg.Scopes) + if scopes == "" { + scopes = DefaultScopes + } + + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", cfg.ClientID) + params.Set("redirect_uri", redirectURI) + params.Set("scope", scopes) + params.Set("state", state) + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + params.Set("access_type", "offline") + params.Set("prompt", "consent") + params.Set("include_granted_scopes", "true") + + return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil +} diff --git a/backend/internal/pkg/geminicli/sanitize.go b/backend/internal/pkg/geminicli/sanitize.go new file mode 100644 index 00000000..f5c407e4 --- /dev/null +++ b/backend/internal/pkg/geminicli/sanitize.go @@ -0,0 +1,46 @@ +package geminicli + +import "strings" + +const maxLogBodyLen = 2048 + +func SanitizeBodyForLogs(body string) string { + body = truncateBase64InMessage(body) + if len(body) > maxLogBodyLen { + body = body[:maxLogBodyLen] + "...[truncated]" + } + return body +} + +func truncateBase64InMessage(message string) string { + const maxBase64Length = 50 + + result := message + offset := 0 + for { + idx := strings.Index(result[offset:], ";base64,") + if idx == -1 { + break + } + actualIdx := offset + idx + start := actualIdx + len(";base64,") + + end := start + for end < len(result) && isBase64Char(result[end]) { + end++ + } + + if end-start > maxBase64Length { + result = result[:start+maxBase64Length] + "...[truncated]" + result[end:] + offset = start + maxBase64Length + len("...[truncated]") + continue + } + offset = end + } + + return result +} + +func isBase64Char(c byte) bool { + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '=' +} diff --git a/backend/internal/pkg/geminicli/token_types.go b/backend/internal/pkg/geminicli/token_types.go new file mode 100644 index 00000000..f3cfbaed --- /dev/null +++ b/backend/internal/pkg/geminicli/token_types.go @@ -0,0 +1,9 @@ +package geminicli + +type TokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + Scope string `json:"scope,omitempty"` +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 1e466244..400ebabb 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -87,3 +87,14 @@ pricing: update_interval_hours: 24 # Hash check interval in minutes hash_check_interval_minutes: 10 + +# ============================================================================= +# Gemini (Optional) +# ============================================================================= +gemini: + oauth: + # Google OAuth Client ID / Secret (for GeminiCli / Code Assist internal API) + client_id: "" + client_secret: "" + # Optional scopes (space-separated). Leave empty to use default cloud-platform scope. + scopes: "" From 2bafc28a9bb05b4285d52e787fb71402e0d7d25c Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 06:43:35 -0800 Subject: [PATCH 02/42] =?UTF-8?q?feat(repository):=20=E5=AE=9E=E7=8E=B0=20?= =?UTF-8?q?Gemini=20OAuth=20=E5=92=8C=20Token=20=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E5=AE=A2=E6=88=B7=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 Gemini OAuth 客户端实现 - 实现 Redis 基础的 Token 缓存 - 添加 gemini-cli Code Assist 客户端封装 --- .../repository/gemini_oauth_client.go | 84 ++++++++++++++++ .../internal/repository/gemini_token_cache.go | 44 +++++++++ .../repository/geminicli_codeassist_client.go | 95 +++++++++++++++++++ 3 files changed, 223 insertions(+) create mode 100644 backend/internal/repository/gemini_oauth_client.go create mode 100644 backend/internal/repository/gemini_token_cache.go create mode 100644 backend/internal/repository/geminicli_codeassist_client.go diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go new file mode 100644 index 00000000..07364368 --- /dev/null +++ b/backend/internal/repository/gemini_oauth_client.go @@ -0,0 +1,84 @@ +package repository + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/service/ports" + + "github.com/imroc/req/v3" +) + +type geminiOAuthClient struct { + tokenURL string + cfg *config.Config +} + +func NewGeminiOAuthClient(cfg *config.Config) ports.GeminiOAuthClient { + return &geminiOAuthClient{ + tokenURL: geminicli.TokenURL, + cfg: cfg, + } +} + +func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { + client := createGeminiReqClient(proxyURL) + + formData := url.Values{} + formData.Set("grant_type", "authorization_code") + formData.Set("client_id", c.cfg.Gemini.OAuth.ClientID) + formData.Set("client_secret", c.cfg.Gemini.OAuth.ClientSecret) + formData.Set("code", code) + formData.Set("code_verifier", codeVerifier) + formData.Set("redirect_uri", redirectURI) + + var tokenResp geminicli.TokenResponse + resp, err := client.R(). + SetContext(ctx). + SetFormDataFromValues(formData). + SetSuccessResult(&tokenResp). + Post(c.tokenURL) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String())) + } + return &tokenResp, nil +} + +func (c *geminiOAuthClient) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + client := createGeminiReqClient(proxyURL) + + formData := url.Values{} + formData.Set("grant_type", "refresh_token") + formData.Set("refresh_token", refreshToken) + formData.Set("client_id", c.cfg.Gemini.OAuth.ClientID) + formData.Set("client_secret", c.cfg.Gemini.OAuth.ClientSecret) + + var tokenResp geminicli.TokenResponse + resp, err := client.R(). + SetContext(ctx). + SetFormDataFromValues(formData). + SetSuccessResult(&tokenResp). + Post(c.tokenURL) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String())) + } + return &tokenResp, nil +} + +func createGeminiReqClient(proxyURL string) *req.Client { + client := req.C().SetTimeout(60 * time.Second) + if proxyURL != "" { + client.SetProxyURL(proxyURL) + } + return client +} diff --git a/backend/internal/repository/gemini_token_cache.go b/backend/internal/repository/gemini_token_cache.go new file mode 100644 index 00000000..9d294605 --- /dev/null +++ b/backend/internal/repository/gemini_token_cache.go @@ -0,0 +1,44 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service/ports" + + "github.com/redis/go-redis/v9" +) + +const ( + geminiTokenKeyPrefix = "gemini:token:" + geminiRefreshLockKeyPrefix = "gemini:refresh_lock:" +) + +type geminiTokenCache struct { + rdb *redis.Client +} + +func NewGeminiTokenCache(rdb *redis.Client) ports.GeminiTokenCache { + return &geminiTokenCache{rdb: rdb} +} + +func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) { + key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) + return c.rdb.Get(ctx, key).Result() +} + +func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error { + key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) + return c.rdb.Set(ctx, key, token, ttl).Err() +} + +func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { + key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) + return c.rdb.SetNX(ctx, key, 1, ttl).Result() +} + +func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error { + key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go new file mode 100644 index 00000000..63f1719c --- /dev/null +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -0,0 +1,95 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/service/ports" + + "github.com/imroc/req/v3" +) + +type geminiCliCodeAssistClient struct { + baseURL string +} + +func NewGeminiCliCodeAssistClient() ports.GeminiCliCodeAssistClient { + return &geminiCliCodeAssistClient{baseURL: geminicli.GeminiCliBaseURL} +} + +func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + if reqBody == nil { + reqBody = defaultLoadCodeAssistRequest() + } + + var out geminicli.LoadCodeAssistResponse + resp, err := createGeminiCliReqClient(proxyURL).R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + SetHeader("Content-Type", "application/json"). + SetHeader("User-Agent", geminicli.GeminiCLIUserAgent). + SetBody(reqBody). + SetSuccessResult(&out). + Post(c.baseURL + "/v1internal:loadCodeAssist") + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String())) + } + return &out, nil +} + +func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) { + if reqBody == nil { + reqBody = defaultOnboardUserRequest() + } + + var out geminicli.OnboardUserResponse + resp, err := createGeminiCliReqClient(proxyURL).R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + SetHeader("Content-Type", "application/json"). + SetHeader("User-Agent", geminicli.GeminiCLIUserAgent). + SetBody(reqBody). + SetSuccessResult(&out). + Post(c.baseURL + "/v1internal:onboardUser") + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String())) + } + return &out, nil +} + +func createGeminiCliReqClient(proxyURL string) *req.Client { + client := req.C().SetTimeout(30 * time.Second) + if proxyURL != "" { + client.SetProxyURL(proxyURL) + } + return client +} + +func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest { + return &geminicli.LoadCodeAssistRequest{ + Metadata: geminicli.LoadCodeAssistMetadata{ + IDEType: "ANTIGRAVITY", + Platform: "PLATFORM_UNSPECIFIED", + PluginType: "GEMINI", + }, + } +} + +func defaultOnboardUserRequest() *geminicli.OnboardUserRequest { + return &geminicli.OnboardUserRequest{ + TierID: "LEGACY", + Metadata: geminicli.LoadCodeAssistMetadata{ + IDEType: "ANTIGRAVITY", + Platform: "PLATFORM_UNSPECIFIED", + PluginType: "GEMINI", + }, + } +} From 71c28e436a1156ac8a3db1924095f4427454f79c Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 06:43:57 -0800 Subject: [PATCH 03/42] =?UTF-8?q?feat(service):=20=E5=AE=9A=E4=B9=89=20Gem?= =?UTF-8?q?ini=20=E6=9C=8D=E5=8A=A1=E7=AB=AF=E5=8F=A3=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 定义 OAuth 服务接口 - 定义 Token 缓存服务接口 - 定义 Code Assist 服务接口 --- backend/internal/service/gemini_oauth.go | 13 +++++++++++++ backend/internal/service/gemini_token_cache.go | 16 ++++++++++++++++ backend/internal/service/geminicli_codeassist.go | 13 +++++++++++++ 3 files changed, 42 insertions(+) create mode 100644 backend/internal/service/gemini_oauth.go create mode 100644 backend/internal/service/gemini_token_cache.go create mode 100644 backend/internal/service/geminicli_codeassist.go diff --git a/backend/internal/service/gemini_oauth.go b/backend/internal/service/gemini_oauth.go new file mode 100644 index 00000000..185d6c55 --- /dev/null +++ b/backend/internal/service/gemini_oauth.go @@ -0,0 +1,13 @@ +package ports + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" +) + +// GeminiOAuthClient performs Google OAuth token exchange/refresh for Gemini integration. +type GeminiOAuthClient interface { + ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) + RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) +} diff --git a/backend/internal/service/gemini_token_cache.go b/backend/internal/service/gemini_token_cache.go new file mode 100644 index 00000000..79a5f948 --- /dev/null +++ b/backend/internal/service/gemini_token_cache.go @@ -0,0 +1,16 @@ +package ports + +import ( + "context" + "time" +) + +// GeminiTokenCache stores short-lived access tokens and coordinates refresh to avoid stampedes. +type GeminiTokenCache interface { + // cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id. + GetAccessToken(ctx context.Context, cacheKey string) (string, error) + SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error + + AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) + ReleaseRefreshLock(ctx context.Context, cacheKey string) error +} diff --git a/backend/internal/service/geminicli_codeassist.go b/backend/internal/service/geminicli_codeassist.go new file mode 100644 index 00000000..2d742b24 --- /dev/null +++ b/backend/internal/service/geminicli_codeassist.go @@ -0,0 +1,13 @@ +package ports + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" +) + +// GeminiCliCodeAssistClient calls GeminiCli internal Code Assist endpoints. +type GeminiCliCodeAssistClient interface { + LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) + OnboardUser(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) +} From dc109827b7acd803540e1871c356e3542f61a2df Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 06:44:18 -0800 Subject: [PATCH 04/42] =?UTF-8?q?feat(service):=20=E5=AE=9E=E7=8E=B0=20Gem?= =?UTF-8?q?ini=20OAuth=20=E5=92=8C=20Token=20=E7=AE=A1=E7=90=86=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 实现 OAuth 授权流程服务 - 添加 Token 提供者和自动刷新机制 - 实现 Gemini Messages API 兼容层 - 更新服务容器注册 --- .../service/gemini_messages_compat_service.go | 1298 +++++++++++++++++ .../internal/service/gemini_oauth_service.go | 305 ++++ .../internal/service/gemini_token_provider.go | 139 ++ .../service/gemini_token_refresher.go | 53 + 4 files changed, 1795 insertions(+) create mode 100644 backend/internal/service/gemini_messages_compat_service.go create mode 100644 backend/internal/service/gemini_oauth_service.go create mode 100644 backend/internal/service/gemini_token_provider.go create mode 100644 backend/internal/service/gemini_token_refresher.go diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go new file mode 100644 index 00000000..49fe7135 --- /dev/null +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -0,0 +1,1298 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math" + mathrand "math/rand" + "net/http" + "regexp" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/service/ports" + + "github.com/gin-gonic/gin" +) + +const geminiStickySessionTTL = time.Hour + +const ( + geminiMaxRetries = 5 + geminiRetryBaseDelay = 1 * time.Second + geminiRetryMaxDelay = 16 * time.Second +) + +type GeminiMessagesCompatService struct { + accountRepo ports.AccountRepository + cache ports.GatewayCache + tokenProvider *GeminiTokenProvider + rateLimitService *RateLimitService + httpUpstream ports.HTTPUpstream +} + +func NewGeminiMessagesCompatService( + accountRepo ports.AccountRepository, + cache ports.GatewayCache, + tokenProvider *GeminiTokenProvider, + rateLimitService *RateLimitService, + httpUpstream ports.HTTPUpstream, +) *GeminiMessagesCompatService { + return &GeminiMessagesCompatService{ + accountRepo: accountRepo, + cache: cache, + tokenProvider: tokenProvider, + rateLimitService: rateLimitService, + httpUpstream: httpUpstream, + } +} + +func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { + cacheKey := "gemini:" + sessionHash + if sessionHash != "" { + accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) + if err == nil && accountID > 0 { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && account.IsSchedulable() && account.Platform == model.PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) + return account, nil + } + } + } + + var accounts []model.Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + + var selected *model.Account + for i := range accounts { + acc := &accounts[i] + if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) { + selected = acc + } + } + } + + if selected == nil { + if requestedModel != "" { + return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel) + } + return nil, errors.New("no available Gemini accounts") + } + + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL) + } + + return selected, nil +} + +func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + + var req struct { + Model string `json:"model"` + Stream bool `json:"stream"` + } + if err := json.Unmarshal(body, &req); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + if strings.TrimSpace(req.Model) == "" { + return nil, fmt.Errorf("missing model") + } + + originalModel := req.Model + mappedModel := req.Model + if account.Type == model.AccountTypeApiKey { + mappedModel = account.GetMappedModel(req.Model) + } + + geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(body) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + var requestIDHeader string + var buildReq func(ctx context.Context) (*http.Request, string, error) + + switch account.Type { + case model.AccountTypeApiKey: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, "", errors.New("Gemini api_key not configured") + } + + baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + if baseURL == "" { + baseURL = geminicli.AIStudioBaseURL + } + + action := "generateContent" + if req.Stream { + action = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action) + if req.Stream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("x-goog-api-key", apiKey) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + + case model.AccountTypeOAuth: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("Gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID == "" { + return nil, "", errors.New("missing project_id in account credentials") + } + + action := "generateContent" + if req.Stream { + action = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action) + if req.Stream { + fullURL += "?alt=sse" + } + + wrapped := map[string]any{ + "model": mappedModel, + "project": projectID, + } + var inner any + if err := json.Unmarshal(geminiReq, &inner); err != nil { + return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) + } + wrapped["request"] = inner + wrappedBytes, _ := json.Marshal(wrapped) + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + + default: + return nil, fmt.Errorf("unsupported account type: %s", account.Type) + } + + if buildReq == nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Gemini upstream not configured") + } + + var resp *http.Response + for attempt := 1; attempt <= geminiMaxRetries; attempt++ { + upstreamReq, idHeader, err := buildReq(ctx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + // Local build error: don't retry. + if strings.Contains(err.Error(), "missing project_id") { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + } + requestIDHeader = idHeader + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + if err != nil { + if attempt < geminiMaxRetries { + log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + sleepGeminiBackoff(attempt) + continue + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") + } + + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + if resp.StatusCode == 429 { + // Mark as rate-limited early so concurrent requests avoid this account. + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < geminiMaxRetries { + log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + sleepGeminiBackoff(attempt) + continue + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") + } + + break + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody) + } + + requestID := resp.Header.Get(requestIDHeader) + if requestID == "" { + requestID = resp.Header.Get("x-goog-request-id") + } + if requestID != "" { + c.Header("x-request-id", requestID) + } + + var usage *ClaudeUsage + var firstTokenMs *int + if req.Stream { + streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponse(c, resp, originalModel) + if err != nil { + return nil, err + } + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: req.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool { + switch statusCode { + case 429, 500, 502, 503, 504, 529: + return true + case 403: + // GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry. + return account != nil && account.Type == model.AccountTypeOAuth + default: + return false + } +} + +func sleepGeminiBackoff(attempt int) { + delay := geminiRetryBaseDelay * time.Duration(1< geminiRetryMaxDelay { + delay = geminiRetryMaxDelay + } + + // +/- 20% jitter + r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) + jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1)) + sleepFor := delay + jitter + if sleepFor < 0 { + sleepFor = 0 + } + time.Sleep(sleepFor) +} + +func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, upstreamStatus int, body []byte) error { + var statusCode int + var errType, errMsg string + + if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil { + errType = mapped.Type + if mapped.Message != "" { + errMsg = mapped.Message + } + if mapped.StatusCode > 0 { + statusCode = mapped.StatusCode + } + } + + switch upstreamStatus { + case 400: + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + if errType == "" { + errType = "invalid_request_error" + } + if errMsg == "" { + errMsg = "Invalid request" + } + case 401: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "authentication_error" + } + if errMsg == "" { + errMsg = "Upstream authentication failed, please contact administrator" + } + case 403: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "permission_error" + } + if errMsg == "" { + errMsg = "Upstream access forbidden, please contact administrator" + } + case 404: + if statusCode == 0 { + statusCode = http.StatusNotFound + } + if errType == "" { + errType = "not_found_error" + } + if errMsg == "" { + errMsg = "Resource not found" + } + case 429: + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + if errType == "" { + errType = "rate_limit_error" + } + if errMsg == "" { + errMsg = "Upstream rate limit exceeded, please retry later" + } + case 529: + if statusCode == 0 { + statusCode = http.StatusServiceUnavailable + } + if errType == "" { + errType = "overloaded_error" + } + if errMsg == "" { + errMsg = "Upstream service overloaded, please retry later" + } + case 500, 502, 503, 504: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + switch upstreamStatus { + case 504: + errType = "timeout_error" + case 503: + errType = "overloaded_error" + default: + errType = "api_error" + } + } + if errMsg == "" { + errMsg = "Upstream service temporarily unavailable" + } + default: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "upstream_error" + } + if errMsg == "" { + errMsg = "Upstream request failed" + } + } + + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": errMsg}, + }) + return fmt.Errorf("upstream error: %d", upstreamStatus) +} + +type claudeErrorMapping struct { + Type string + Message string + StatusCode int +} + +func mapGeminiErrorBodyToClaudeError(body []byte) *claudeErrorMapping { + if len(body) == 0 { + return nil + } + + var parsed struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + } `json:"error"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return nil + } + if strings.TrimSpace(parsed.Error.Status) == "" && parsed.Error.Code == 0 && strings.TrimSpace(parsed.Error.Message) == "" { + return nil + } + + mapped := &claudeErrorMapping{ + Type: mapGeminiStatusToClaudeErrorType(parsed.Error.Status), + Message: "", + } + if mapped.Type == "" { + mapped.Type = "upstream_error" + } + + switch strings.ToUpper(strings.TrimSpace(parsed.Error.Status)) { + case "INVALID_ARGUMENT": + mapped.StatusCode = http.StatusBadRequest + case "NOT_FOUND": + mapped.StatusCode = http.StatusNotFound + case "RESOURCE_EXHAUSTED": + mapped.StatusCode = http.StatusTooManyRequests + default: + // Keep StatusCode unset and let HTTP status mapping decide. + } + + // Keep messages generic by default; upstream error message can be long or include sensitive fragments. + return mapped +} + +func mapGeminiStatusToClaudeErrorType(status string) string { + switch strings.ToUpper(strings.TrimSpace(status)) { + case "INVALID_ARGUMENT": + return "invalid_request_error" + case "PERMISSION_DENIED": + return "permission_error" + case "NOT_FOUND": + return "not_found_error" + case "RESOURCE_EXHAUSTED": + return "rate_limit_error" + case "UNAUTHENTICATED": + return "authentication_error" + case "UNAVAILABLE": + return "overloaded_error" + case "INTERNAL": + return "api_error" + case "DEADLINE_EXCEEDED": + return "timeout_error" + default: + return "" + } +} + +type geminiStreamResult struct { + usage *ClaudeUsage + firstTokenMs *int +} + +func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) { + body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") + } + + geminiResp, err := unwrapGeminiResponse(body) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel) + c.JSON(http.StatusOK, claudeResp) + + return usage, nil +} + +func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + messageID := "msg_" + randomHex(12) + messageStart := map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": originalModel, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": 0, + "output_tokens": 0, + }, + }, + } + writeSSE(c.Writer, "message_start", messageStart) + flusher.Flush() + + var firstTokenMs *int + var usage ClaudeUsage + finishReason := "" + sawToolUse := false + + nextBlockIndex := 0 + openBlockIndex := -1 + openBlockType := "" + seenText := "" + + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, fmt.Errorf("stream read error: %w", err) + } + + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" || payload == "[DONE]" { + continue + } + + geminiResp, err := unwrapGeminiResponse([]byte(payload)) + if err != nil { + continue + } + + if fr := extractGeminiFinishReason(geminiResp); fr != "" { + finishReason = fr + } + + parts := extractGeminiParts(geminiResp) + for _, part := range parts { + if text, ok := part["text"].(string); ok && text != "" { + delta, newSeen := computeGeminiTextDelta(seenText, text) + seenText = newSeen + if delta == "" { + continue + } + + if openBlockType != "text" { + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + } + openBlockType = "text" + openBlockIndex = nextBlockIndex + nextBlockIndex++ + writeSSE(c.Writer, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": openBlockIndex, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + writeSSE(c.Writer, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": openBlockIndex, + "delta": map[string]any{ + "type": "text_delta", + "text": delta, + }, + }) + flusher.Flush() + continue + } + + if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil { + name, _ := fc["name"].(string) + args := fc["args"] + if strings.TrimSpace(name) == "" { + name = "tool" + } + + // Close any open block before tool_use. + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + openBlockIndex = -1 + openBlockType = "" + } + + toolID := "toolu_" + randomHex(8) + toolIndex := nextBlockIndex + nextBlockIndex++ + sawToolUse = true + + writeSSE(c.Writer, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": toolIndex, + "content_block": map[string]any{ + "type": "tool_use", + "id": toolID, + "name": name, + "input": map[string]any{}, + }, + }) + + argsJSON := "{}" + if args != nil { + if b, err := json.Marshal(args); err == nil { + argsJSON = string(b) + } + } + writeSSE(c.Writer, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": toolIndex, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": argsJSON, + }, + }) + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": toolIndex, + }) + flusher.Flush() + } + } + + if u := extractGeminiUsage(geminiResp); u != nil { + usage = *u + } + } + + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + } + + stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason) + if sawToolUse { + stopReason = "tool_use" + } + + usageObj := map[string]any{ + "output_tokens": usage.OutputTokens, + } + if usage.InputTokens > 0 { + usageObj["input_tokens"] = usage.InputTokens + } + writeSSE(c.Writer, "message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": usageObj, + }) + writeSSE(c.Writer, "message_stop", map[string]any{ + "type": "message_stop", + }) + flusher.Flush() + + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil +} + +func writeSSE(w io.Writer, event string, data any) { + if event != "" { + _, _ = fmt.Fprintf(w, "event: %s\n", event) + } + b, _ := json.Marshal(data) + _, _ = fmt.Fprintf(w, "data: %s\n\n", string(b)) +} + +func randomHex(nBytes int) string { + b := make([]byte, nBytes) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + +func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status int, errType, message string) error { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": message}, + }) + return fmt.Errorf("%s", message) +} + +func unwrapGeminiResponse(raw []byte) (map[string]any, error) { + var outer map[string]any + if err := json.Unmarshal(raw, &outer); err != nil { + return nil, err + } + if resp, ok := outer["response"].(map[string]any); ok && resp != nil { + return resp, nil + } + return outer, nil +} + +func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) { + usage := extractGeminiUsage(geminiResp) + if usage == nil { + usage = &ClaudeUsage{} + } + + contentBlocks := make([]any, 0) + sawToolUse := false + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if content, ok := cand["content"].(map[string]any); ok { + if parts, ok := content["parts"].([]any); ok { + for _, part := range parts { + pm, ok := part.(map[string]any) + if !ok { + continue + } + if text, ok := pm["text"].(string); ok && text != "" { + contentBlocks = append(contentBlocks, map[string]any{ + "type": "text", + "text": text, + }) + } + if fc, ok := pm["functionCall"].(map[string]any); ok { + name, _ := fc["name"].(string) + if strings.TrimSpace(name) == "" { + name = "tool" + } + args := fc["args"] + sawToolUse = true + contentBlocks = append(contentBlocks, map[string]any{ + "type": "tool_use", + "id": "toolu_" + randomHex(8), + "name": name, + "input": args, + }) + } + } + } + } + } + } + + stopReason := mapGeminiFinishReasonToClaudeStopReason(extractGeminiFinishReason(geminiResp)) + if sawToolUse { + stopReason = "tool_use" + } + + resp := map[string]any{ + "id": "msg_" + randomHex(12), + "type": "message", + "role": "assistant", + "model": originalModel, + "content": contentBlocks, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": usage.InputTokens, + "output_tokens": usage.OutputTokens, + }, + } + + return resp, usage +} + +func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage { + usageMeta, ok := geminiResp["usageMetadata"].(map[string]any) + if !ok || usageMeta == nil { + return nil + } + prompt, _ := asInt(usageMeta["promptTokenCount"]) + cand, _ := asInt(usageMeta["candidatesTokenCount"]) + return &ClaudeUsage{ + InputTokens: prompt, + OutputTokens: cand, + } +} + +func asInt(v any) (int, bool) { + switch t := v.(type) { + case float64: + return int(t), true + case int: + return t, true + case int64: + return int(t), true + case json.Number: + i, err := t.Int64() + if err != nil { + return 0, false + } + return int(i), true + default: + return 0, false + } +} + +func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, body []byte) { + if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) + return + } + if statusCode != 429 { + return + } + resetAt := parseGeminiRateLimitResetTime(body) + if resetAt == nil { + ra := time.Now().Add(5 * time.Minute) + _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) + return + } + _ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0)) +} + +func parseGeminiRateLimitResetTime(body []byte) *int64 { + // Try to parse metadata.quotaResetDelay like "12.345s" + var parsed map[string]any + if err := json.Unmarshal(body, &parsed); err == nil { + if errObj, ok := parsed["error"].(map[string]any); ok { + if msg, ok := errObj["message"].(string); ok { + if looksLikeGeminiDailyQuota(msg) { + if ts := nextGeminiDailyResetUnix(); ts != nil { + return ts + } + } + } + if details, ok := errObj["details"].([]any); ok { + for _, d := range details { + dm, ok := d.(map[string]any) + if !ok { + continue + } + if meta, ok := dm["metadata"].(map[string]any); ok { + if v, ok := meta["quotaResetDelay"].(string); ok { + if dur, err := time.ParseDuration(v); err == nil { + ts := time.Now().Unix() + int64(dur.Seconds()) + return &ts + } + } + } + } + } + } + } + + // Match "Please retry in Xs" + retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`) + matches := retryInRegex.FindStringSubmatch(string(body)) + if len(matches) == 2 { + if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { + ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) + return &ts + } + } + + return nil +} + +func looksLikeGeminiDailyQuota(message string) bool { + m := strings.ToLower(message) + if strings.Contains(m, "per day") || strings.Contains(m, "requests per day") || strings.Contains(m, "quota") && strings.Contains(m, "per day") { + return true + } + return false +} + +func nextGeminiDailyResetUnix() *int64 { + loc, err := time.LoadLocation("America/Los_Angeles") + if err != nil { + // Fallback: PST without DST. + loc = time.FixedZone("PST", -8*3600) + } + now := time.Now().In(loc) + reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc) + if !reset.After(now) { + reset = reset.Add(24 * time.Hour) + } + ts := reset.Unix() + return &ts +} + +func extractGeminiFinishReason(geminiResp map[string]any) string { + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if fr, ok := cand["finishReason"].(string); ok { + return fr + } + } + } + return "" +} + +func extractGeminiParts(geminiResp map[string]any) []map[string]any { + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if content, ok := cand["content"].(map[string]any); ok { + if partsAny, ok := content["parts"].([]any); ok && len(partsAny) > 0 { + out := make([]map[string]any, 0, len(partsAny)) + for _, p := range partsAny { + pm, ok := p.(map[string]any) + if !ok { + continue + } + out = append(out, pm) + } + return out + } + } + } + } + return nil +} + +func computeGeminiTextDelta(seen, incoming string) (delta, newSeen string) { + incoming = strings.TrimSuffix(incoming, "\u0000") + if incoming == "" { + return "", seen + } + + // Cumulative mode: incoming contains full text so far. + if strings.HasPrefix(incoming, seen) { + return strings.TrimPrefix(incoming, seen), incoming + } + // Duplicate/rewind: ignore. + if strings.HasPrefix(seen, incoming) { + return "", seen + } + // Delta mode: treat incoming as incremental chunk. + return incoming, seen + incoming +} + +func mapGeminiFinishReasonToClaudeStopReason(finishReason string) string { + switch strings.ToUpper(strings.TrimSpace(finishReason)) { + case "MAX_TOKENS": + return "max_tokens" + case "STOP": + return "end_turn" + default: + return "end_turn" + } +} + +func convertClaudeMessagesToGeminiGenerateContent(body []byte) ([]byte, error) { + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + toolUseIDToName := make(map[string]string) + + systemText := extractClaudeSystemText(req["system"]) + contents, err := convertClaudeMessagesToGeminiContents(req["messages"], toolUseIDToName) + if err != nil { + return nil, err + } + + out := make(map[string]any) + if systemText != "" { + out["systemInstruction"] = map[string]any{ + "parts": []any{map[string]any{"text": systemText}}, + } + } + out["contents"] = contents + + if tools := convertClaudeToolsToGeminiTools(req["tools"]); tools != nil { + out["tools"] = tools + } + + generationConfig := convertClaudeGenerationConfig(req) + if generationConfig != nil { + out["generationConfig"] = generationConfig + } + + stripGeminiFunctionIDs(out) + return json.Marshal(out) +} + +func stripGeminiFunctionIDs(req map[string]any) { + // Defensive cleanup: some upstreams reject unexpected `id` fields in functionCall/functionResponse. + contents, ok := req["contents"].([]any) + if !ok { + return + } + for _, c := range contents { + cm, ok := c.(map[string]any) + if !ok { + continue + } + contentParts, ok := cm["parts"].([]any) + if !ok { + continue + } + for _, p := range contentParts { + pm, ok := p.(map[string]any) + if !ok { + continue + } + if fc, ok := pm["functionCall"].(map[string]any); ok && fc != nil { + delete(fc, "id") + } + if fr, ok := pm["functionResponse"].(map[string]any); ok && fr != nil { + delete(fr, "id") + } + } + } +} + +func extractClaudeSystemText(system any) string { + switch v := system.(type) { + case string: + return strings.TrimSpace(v) + case []any: + var parts []string + for _, p := range v { + pm, ok := p.(map[string]any) + if !ok { + continue + } + if t, _ := pm["type"].(string); t != "text" { + continue + } + if text, ok := pm["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, text) + } + } + return strings.TrimSpace(strings.Join(parts, "\n")) + default: + return "" + } +} + +func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[string]string) ([]any, error) { + arr, ok := messages.([]any) + if !ok { + return nil, errors.New("messages must be an array") + } + + out := make([]any, 0, len(arr)) + for _, m := range arr { + mm, ok := m.(map[string]any) + if !ok { + continue + } + role, _ := mm["role"].(string) + role = strings.ToLower(strings.TrimSpace(role)) + gRole := "user" + if role == "assistant" { + gRole = "model" + } + + parts := make([]any, 0) + switch content := mm["content"].(type) { + case string: + if strings.TrimSpace(content) != "" { + parts = append(parts, map[string]any{"text": content}) + } + case []any: + for _, block := range content { + bm, ok := block.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + switch bt { + case "text": + if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, map[string]any{"text": text}) + } + case "tool_use": + id, _ := bm["id"].(string) + name, _ := bm["name"].(string) + if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" { + toolUseIDToName[id] = name + } + parts = append(parts, map[string]any{ + "functionCall": map[string]any{ + "name": name, + "args": bm["input"], + }, + }) + case "tool_result": + toolUseID, _ := bm["tool_use_id"].(string) + name := toolUseIDToName[toolUseID] + if name == "" { + name = "tool" + } + parts = append(parts, map[string]any{ + "functionResponse": map[string]any{ + "name": name, + "response": map[string]any{ + "content": extractClaudeContentText(bm["content"]), + }, + }, + }) + case "image": + if src, ok := bm["source"].(map[string]any); ok { + if srcType, _ := src["type"].(string); srcType == "base64" { + mediaType, _ := src["media_type"].(string) + data, _ := src["data"].(string) + if mediaType != "" && data != "" { + parts = append(parts, map[string]any{ + "inlineData": map[string]any{ + "mimeType": mediaType, + "data": data, + }, + }) + } + } + } + default: + // best-effort: preserve unknown blocks as text + if b, err := json.Marshal(bm); err == nil { + parts = append(parts, map[string]any{"text": string(b)}) + } + } + } + default: + // ignore + } + + out = append(out, map[string]any{ + "role": gRole, + "parts": parts, + }) + } + return out, nil +} + +func extractClaudeContentText(v any) string { + switch t := v.(type) { + case string: + return t + case []any: + var sb strings.Builder + for _, part := range t { + pm, ok := part.(map[string]any) + if !ok { + continue + } + if pm["type"] == "text" { + if text, ok := pm["text"].(string); ok { + sb.WriteString(text) + } + } + } + return sb.String() + default: + b, _ := json.Marshal(t) + return string(b) + } +} + +func convertClaudeToolsToGeminiTools(tools any) []any { + arr, ok := tools.([]any) + if !ok || len(arr) == 0 { + return nil + } + + funcDecls := make([]any, 0, len(arr)) + for _, t := range arr { + tm, ok := t.(map[string]any) + if !ok { + continue + } + name, _ := tm["name"].(string) + desc, _ := tm["description"].(string) + params := tm["input_schema"] + if name == "" { + continue + } + funcDecls = append(funcDecls, map[string]any{ + "name": name, + "description": desc, + "parameters": params, + }) + } + + if len(funcDecls) == 0 { + return nil + } + return []any{ + map[string]any{ + "functionDeclarations": funcDecls, + }, + } +} + +func convertClaudeGenerationConfig(req map[string]any) map[string]any { + out := make(map[string]any) + if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { + out["maxOutputTokens"] = mt + } + if temp, ok := req["temperature"].(float64); ok { + out["temperature"] = temp + } + if topP, ok := req["top_p"].(float64); ok { + out["topP"] = topP + } + if stopSeq, ok := req["stop_sequences"].([]any); ok && len(stopSeq) > 0 { + out["stopSequences"] = stopSeq + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go new file mode 100644 index 00000000..067a2455 --- /dev/null +++ b/backend/internal/service/gemini_oauth_service.go @@ -0,0 +1,305 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/service/ports" +) + +type GeminiOAuthService struct { + sessionStore *geminicli.SessionStore + proxyRepo ports.ProxyRepository + oauthClient ports.GeminiOAuthClient + codeAssist ports.GeminiCliCodeAssistClient + cfg *config.Config +} + +func NewGeminiOAuthService( + proxyRepo ports.ProxyRepository, + oauthClient ports.GeminiOAuthClient, + codeAssist ports.GeminiCliCodeAssistClient, + cfg *config.Config, +) *GeminiOAuthService { + return &GeminiOAuthService{ + sessionStore: geminicli.NewSessionStore(), + proxyRepo: proxyRepo, + oauthClient: oauthClient, + codeAssist: codeAssist, + cfg: cfg, + } +} + +type GeminiAuthURLResult struct { + AuthURL string `json:"auth_url"` + SessionID string `json:"session_id"` + State string `json:"state"` +} + +func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*GeminiAuthURLResult, error) { + state, err := geminicli.GenerateState() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + codeVerifier, err := geminicli.GenerateCodeVerifier() + if err != nil { + return nil, fmt.Errorf("failed to generate code verifier: %w", err) + } + codeChallenge := geminicli.GenerateCodeChallenge(codeVerifier) + sessionID, err := geminicli.GenerateSessionID() + if err != nil { + return nil, fmt.Errorf("failed to generate session ID: %w", err) + } + + var proxyURL string + if proxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + session := &geminicli.OAuthSession{ + State: state, + CodeVerifier: codeVerifier, + ProxyURL: proxyURL, + RedirectURI: redirectURI, + CreatedAt: time.Now(), + } + s.sessionStore.Set(sessionID, session) + + oauthCfg := geminicli.OAuthConfig{ + ClientID: s.cfg.Gemini.OAuth.ClientID, + ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, + Scopes: s.cfg.Gemini.OAuth.Scopes, + } + + authURL, err := geminicli.BuildAuthorizationURL(oauthCfg, state, codeChallenge, redirectURI) + if err != nil { + return nil, err + } + + return &GeminiAuthURLResult{ + AuthURL: authURL, + SessionID: sessionID, + State: state, + }, nil +} + +type GeminiExchangeCodeInput struct { + SessionID string + State string + Code string + RedirectURI string + ProxyID *int64 +} + +type GeminiTokenInfo struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + TokenType string `json:"token_type"` + Scope string `json:"scope,omitempty"` + ProjectID string `json:"project_id,omitempty"` +} + +func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { + session, ok := s.sessionStore.Get(input.SessionID) + if !ok { + return nil, fmt.Errorf("session not found or expired") + } + if strings.TrimSpace(input.State) == "" || input.State != session.State { + return nil, fmt.Errorf("invalid state") + } + + proxyURL := session.ProxyURL + if input.ProxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + redirectURI := session.RedirectURI + if strings.TrimSpace(input.RedirectURI) != "" { + redirectURI = input.RedirectURI + } + + tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) + if err != nil { + return nil, fmt.Errorf("failed to exchange code: %w", err) + } + s.sessionStore.Delete(input.SessionID) + + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn + projectID, _ := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + + return &GeminiTokenInfo{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + TokenType: tokenResp.TokenType, + ExpiresIn: tokenResp.ExpiresIn, + ExpiresAt: expiresAt, + Scope: tokenResp.Scope, + ProjectID: projectID, + }, nil +} + +func (s *GeminiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*GeminiTokenInfo, error) { + var lastErr error + + for attempt := 0; attempt <= 3; attempt++ { + if attempt > 0 { + backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + time.Sleep(backoff) + } + + tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) + if err == nil { + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn + return &GeminiTokenInfo{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + TokenType: tokenResp.TokenType, + ExpiresIn: tokenResp.ExpiresIn, + ExpiresAt: expiresAt, + Scope: tokenResp.Scope, + }, nil + } + + if isNonRetryableGeminiOAuthError(err) { + return nil, err + } + lastErr = err + } + + return nil, fmt.Errorf("token refresh failed after retries: %w", lastErr) +} + +func isNonRetryableGeminiOAuthError(err error) bool { + msg := err.Error() + nonRetryable := []string{ + "invalid_grant", + "invalid_client", + "unauthorized_client", + "access_denied", + } + for _, needle := range nonRetryable { + if strings.Contains(msg, needle) { + return true + } + } + return false +} + +func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) { + if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { + return nil, fmt.Errorf("account is not a Gemini OAuth account") + } + + refreshToken := account.GetCredential("refresh_token") + if strings.TrimSpace(refreshToken) == "" { + return nil, fmt.Errorf("no refresh token available") + } + + var proxyURL string + if account.ProxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + return s.RefreshToken(ctx, refreshToken, proxyURL) +} + +func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any { + creds := map[string]any{ + "access_token": tokenInfo.AccessToken, + "expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10), + } + if tokenInfo.RefreshToken != "" { + creds["refresh_token"] = tokenInfo.RefreshToken + } + if tokenInfo.TokenType != "" { + creds["token_type"] = tokenInfo.TokenType + } + if tokenInfo.Scope != "" { + creds["scope"] = tokenInfo.Scope + } + if tokenInfo.ProjectID != "" { + creds["project_id"] = tokenInfo.ProjectID + } + return creds +} + +func (s *GeminiOAuthService) Stop() { + s.sessionStore.Stop() +} + +func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { + if s.codeAssist == nil { + return "", errors.New("code assist client not configured") + } + + loadResp, err := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) + if err == nil && strings.TrimSpace(loadResp.CurrentTier) != "" && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { + return strings.TrimSpace(loadResp.CloudAICompanionProject), nil + } + + // pick default tier from allowedTiers, fallback to LEGACY. + tierID := "LEGACY" + if loadResp != nil { + for _, tier := range loadResp.AllowedTiers { + if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { + tierID = tier.ID + break + } + } + } + + req := &geminicli.OnboardUserRequest{ + TierID: tierID, + Metadata: geminicli.LoadCodeAssistMetadata{ + IDEType: "ANTIGRAVITY", + Platform: "PLATFORM_UNSPECIFIED", + PluginType: "GEMINI", + }, + } + + maxAttempts := 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req) + if err != nil { + return "", err + } + if resp.Done { + if resp.Response == nil || resp.Response.CloudAICompanionProject == nil { + return "", errors.New("onboardUser completed but no project_id returned") + } + switch v := resp.Response.CloudAICompanionProject.(type) { + case string: + return strings.TrimSpace(v), nil + case map[string]any: + if id, ok := v["id"].(string); ok { + return strings.TrimSpace(id), nil + } + } + return "", errors.New("onboardUser returned unsupported project_id format") + } + time.Sleep(2 * time.Second) + } + + return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) +} diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go new file mode 100644 index 00000000..51a2f54a --- /dev/null +++ b/backend/internal/service/gemini_token_provider.go @@ -0,0 +1,139 @@ +package service + +import ( + "context" + "errors" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service/ports" +) + +const ( + geminiTokenRefreshSkew = 3 * time.Minute + geminiTokenCacheSkew = 5 * time.Minute +) + +type GeminiTokenProvider struct { + accountRepo ports.AccountRepository + tokenCache ports.GeminiTokenCache + geminiOAuthService *GeminiOAuthService +} + +func NewGeminiTokenProvider( + accountRepo ports.AccountRepository, + tokenCache ports.GeminiTokenCache, + geminiOAuthService *GeminiOAuthService, +) *GeminiTokenProvider { + return &GeminiTokenProvider{ + accountRepo: accountRepo, + tokenCache: tokenCache, + geminiOAuthService: geminiOAuthService, + } +} + +func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model.Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { + return "", errors.New("not a gemini oauth account") + } + + cacheKey := geminiTokenCacheKey(account) + + // 1) Try cache first. + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + + // 2) Refresh if needed (pre-expiry skew). + expiresAt := parseExpiresAt(account) + needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew + if needsRefresh && p.tokenCache != nil { + locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if err == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + + // Re-check after lock (another worker may have refreshed). + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + expiresAt = parseExpiresAt(account) + if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew { + if p.geminiOAuthService == nil { + return "", errors.New("gemini oauth service not configured") + } + tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account) + if err != nil { + return "", err + } + newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + account.Credentials = model.JSONB(newCredentials) + _ = p.accountRepo.Update(ctx, account) + expiresAt = parseExpiresAt(account) + } + } + } + + accessToken := account.GetCredential("access_token") + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found in credentials") + } + + // 3) Populate cache with TTL. + if p.tokenCache != nil { + ttl := 30 * time.Minute + if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > geminiTokenCacheSkew: + ttl = until - geminiTokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + + return accessToken, nil +} + +func geminiTokenCacheKey(account *model.Account) string { + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID != "" { + return projectID + } + return "account:" + strconv.FormatInt(account.ID, 10) +} + +func parseExpiresAt(account *model.Account) *time.Time { + raw := strings.TrimSpace(account.GetCredential("expires_at")) + if raw == "" { + return nil + } + if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 { + t := time.Unix(unixSec, 0) + return &t + } + if t, err := time.Parse(time.RFC3339, raw); err == nil { + return &t + } + return nil +} diff --git a/backend/internal/service/gemini_token_refresher.go b/backend/internal/service/gemini_token_refresher.go new file mode 100644 index 00000000..25ad699d --- /dev/null +++ b/backend/internal/service/gemini_token_refresher.go @@ -0,0 +1,53 @@ +package service + +import ( + "context" + "strconv" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" +) + +type GeminiTokenRefresher struct { + geminiOAuthService *GeminiOAuthService +} + +func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiTokenRefresher { + return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService} +} + +func (r *GeminiTokenRefresher) CanRefresh(account *model.Account) bool { + return account.Platform == model.PlatformGemini && account.Type == model.AccountTypeOAuth +} + +func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool { + if !r.CanRefresh(account) { + return false + } + expiresAtStr := account.GetCredential("expires_at") + if expiresAtStr == "" { + return false + } + expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) + if err != nil { + return false + } + expiryTime := time.Unix(expiresAt, 0) + return time.Until(expiryTime) < refreshWindow +} + +func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) { + tokenInfo, err := r.geminiOAuthService.RefreshAccountToken(ctx, account) + if err != nil { + return nil, err + } + + newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + + return newCredentials, nil +} From 55258bf099029f80c7e134b68f02069c807a2375 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 06:44:40 -0800 Subject: [PATCH 05/42] =?UTF-8?q?feat(service):=20=E6=89=A9=E5=B1=95=20CRS?= =?UTF-8?q?=20=E5=90=8C=E6=AD=A5=E5=92=8C=E5=AE=9A=E4=BB=B7=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E6=94=AF=E6=8C=81=20Gemini?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - CRS 同步服务新增 Gemini 账号同步逻辑(+273行) - 定价服务扩展 Gemini 模型定价计算(+99行) - 更新 Token 刷新服务集成 Gemini - 更新相关单元测试 --- .../repository/pricing_service_test.go | 4 +- backend/internal/service/crs_sync_service.go | 273 +++++++++++++++++- backend/internal/service/pricing_service.go | 99 ++++++- .../internal/service/token_refresh_service.go | 2 + 4 files changed, 360 insertions(+), 18 deletions(-) diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index 8cfc8222..c51317a4 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -120,10 +120,9 @@ func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() { func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { started := make(chan struct{}) - block := make(chan struct{}) s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { close(started) - <-block + <-r.Context().Done() })) ctx, cancel := context.WithCancel(s.ctx) @@ -136,7 +135,6 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { <-started cancel() - close(block) err := <-done require.Error(s.T(), err) diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index e1f9d252..6a0241fb 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/url" + "strconv" "strings" "time" @@ -20,6 +21,7 @@ type CRSSyncService struct { proxyRepo ProxyRepository oauthService *OAuthService openaiOAuthService *OpenAIOAuthService + geminiOAuthService *GeminiOAuthService } func NewCRSSyncService( @@ -27,12 +29,14 @@ func NewCRSSyncService( proxyRepo ProxyRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, + geminiOAuthService *GeminiOAuthService, ) *CRSSyncService { return &CRSSyncService{ accountRepo: accountRepo, proxyRepo: proxyRepo, oauthService: oauthService, openaiOAuthService: openaiOAuthService, + geminiOAuthService: geminiOAuthService, } } @@ -77,6 +81,8 @@ type crsExportResponse struct { ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"` OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"` OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"` + GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"` + GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"` } `json:"data"` } @@ -149,6 +155,37 @@ type crsOpenAIOAuthAccount struct { Extra map[string]any `json:"extra"` } +type crsGeminiOAuthAccount struct { + Kind string `json:"kind"` + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Platform string `json:"platform"` + AuthType string `json:"authType"` // oauth + IsActive bool `json:"isActive"` + Schedulable bool `json:"schedulable"` + Priority int `json:"priority"` + Status string `json:"status"` + Proxy *crsProxy `json:"proxy"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` +} + +type crsGeminiAPIKeyAccount struct { + Kind string `json:"kind"` + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Platform string `json:"platform"` + IsActive bool `json:"isActive"` + Schedulable bool `json:"schedulable"` + Priority int `json:"priority"` + Status string `json:"status"` + Proxy *crsProxy `json:"proxy"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` +} + func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { baseURL, err := normalizeBaseURL(input.BaseURL) if err != nil { @@ -176,7 +213,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput Items: make( []SyncFromCRSItemResult, 0, - len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts), + len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts)+len(exported.Data.GeminiOAuthAccounts)+len(exported.Data.GeminiAPIKeyAccounts), ), } @@ -680,6 +717,225 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput result.Items = append(result.Items, item) } + // Gemini OAuth -> sub2api gemini oauth + for _, src := range exported.Data.GeminiOAuthAccounts { + item := SyncFromCRSItemResult{ + CRSAccountID: src.ID, + Kind: src.Kind, + Name: src.Name, + } + + refreshToken, _ := src.Credentials["refresh_token"].(string) + if strings.TrimSpace(refreshToken) == "" { + item.Action = "failed" + item.Error = "missing refresh_token" + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name)) + if err != nil { + item.Action = "failed" + item.Error = "proxy sync failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + credentials := sanitizeCredentialsMap(src.Credentials) + if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" { + credentials["token_type"] = "Bearer" + } + // Convert expires_at from RFC3339 to Unix seconds string (recommended to keep consistent with GetCredential()) + if expiresAtStr, ok := credentials["expires_at"].(string); ok && strings.TrimSpace(expiresAtStr) != "" { + if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil { + credentials["expires_at"] = strconv.FormatInt(t.Unix(), 10) + } + } + + extra := make(map[string]any) + if src.Extra != nil { + for k, v := range src.Extra { + extra[k] = v + } + } + extra["crs_account_id"] = src.ID + extra["crs_kind"] = src.Kind + extra["crs_synced_at"] = now + + existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID) + if err != nil { + item.Action = "failed" + item.Error = "db lookup failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + if existing == nil { + account := &model.Account{ + Name: defaultName(src.Name, src.ID), + Platform: model.PlatformGemini, + Type: model.AccountTypeOAuth, + Credentials: model.JSONB(credentials), + Extra: model.JSONB(extra), + ProxyID: proxyID, + Concurrency: 3, + Priority: clampPriority(src.Priority), + Status: mapCRSStatus(src.IsActive, src.Status), + Schedulable: src.Schedulable, + } + if err := s.accountRepo.Create(ctx, account); err != nil { + item.Action = "failed" + item.Error = "create failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { + account.Credentials = refreshedCreds + _ = s.accountRepo.Update(ctx, account) + } + item.Action = "created" + result.Created++ + result.Items = append(result.Items, item) + continue + } + + existing.Extra = mergeJSONB(existing.Extra, extra) + existing.Name = defaultName(src.Name, src.ID) + existing.Platform = model.PlatformGemini + existing.Type = model.AccountTypeOAuth + existing.Credentials = mergeJSONB(existing.Credentials, credentials) + if proxyID != nil { + existing.ProxyID = proxyID + } + existing.Concurrency = 3 + existing.Priority = clampPriority(src.Priority) + existing.Status = mapCRSStatus(src.IsActive, src.Status) + existing.Schedulable = src.Schedulable + + if err := s.accountRepo.Update(ctx, existing); err != nil { + item.Action = "failed" + item.Error = "update failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { + existing.Credentials = refreshedCreds + _ = s.accountRepo.Update(ctx, existing) + } + + item.Action = "updated" + result.Updated++ + result.Items = append(result.Items, item) + } + + // Gemini API Key -> sub2api gemini apikey + for _, src := range exported.Data.GeminiAPIKeyAccounts { + item := SyncFromCRSItemResult{ + CRSAccountID: src.ID, + Kind: src.Kind, + Name: src.Name, + } + + apiKey, _ := src.Credentials["api_key"].(string) + if strings.TrimSpace(apiKey) == "" { + item.Action = "failed" + item.Error = "missing api_key" + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name)) + if err != nil { + item.Action = "failed" + item.Error = "proxy sync failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + credentials := sanitizeCredentialsMap(src.Credentials) + if baseURL, ok := credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" { + credentials["base_url"] = "https://generativelanguage.googleapis.com" + } + + extra := make(map[string]any) + if src.Extra != nil { + for k, v := range src.Extra { + extra[k] = v + } + } + extra["crs_account_id"] = src.ID + extra["crs_kind"] = src.Kind + extra["crs_synced_at"] = now + + existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID) + if err != nil { + item.Action = "failed" + item.Error = "db lookup failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + if existing == nil { + account := &model.Account{ + Name: defaultName(src.Name, src.ID), + Platform: model.PlatformGemini, + Type: model.AccountTypeApiKey, + Credentials: model.JSONB(credentials), + Extra: model.JSONB(extra), + ProxyID: proxyID, + Concurrency: 3, + Priority: clampPriority(src.Priority), + Status: mapCRSStatus(src.IsActive, src.Status), + Schedulable: src.Schedulable, + } + if err := s.accountRepo.Create(ctx, account); err != nil { + item.Action = "failed" + item.Error = "create failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + item.Action = "created" + result.Created++ + result.Items = append(result.Items, item) + continue + } + + existing.Extra = mergeJSONB(existing.Extra, extra) + existing.Name = defaultName(src.Name, src.ID) + existing.Platform = model.PlatformGemini + existing.Type = model.AccountTypeApiKey + existing.Credentials = mergeJSONB(existing.Credentials, credentials) + if proxyID != nil { + existing.ProxyID = proxyID + } + existing.Concurrency = 3 + existing.Priority = clampPriority(src.Priority) + existing.Status = mapCRSStatus(src.IsActive, src.Status) + existing.Schedulable = src.Schedulable + + if err := s.accountRepo.Update(ctx, existing); err != nil { + item.Action = "failed" + item.Error = "update failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + item.Action = "updated" + result.Updated++ + result.Items = append(result.Items, item) + } + return result, nil } @@ -947,6 +1203,21 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A } } } + case model.PlatformGemini: + if s.geminiOAuthService == nil { + return nil + } + tokenInfo, refreshErr := s.geminiOAuthService.RefreshAccountToken(ctx, account) + if refreshErr != nil { + err = refreshErr + } else { + newCredentials = s.geminiOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + } default: return nil } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 56aab2bc..6d95f12f 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -393,27 +393,32 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing return nil } - // 标准化模型名称 - modelLower := strings.ToLower(modelName) + // 标准化模型名称(同时兼容 "models/xxx"、VertexAI 资源名等前缀) + modelLower := strings.ToLower(strings.TrimSpace(modelName)) + lookupCandidates := s.buildModelLookupCandidates(modelLower) // 1. 精确匹配 - if pricing, ok := s.pricingData[modelLower]; ok { - return pricing - } - if pricing, ok := s.pricingData[modelName]; ok { - return pricing + for _, candidate := range lookupCandidates { + if candidate == "" { + continue + } + if pricing, ok := s.pricingData[candidate]; ok { + return pricing + } } // 2. 处理常见的模型名称变体 // claude-opus-4-5-20251101 -> claude-opus-4.5-20251101 - normalized := strings.ReplaceAll(modelLower, "-4-5-", "-4.5-") - if pricing, ok := s.pricingData[normalized]; ok { - return pricing + for _, candidate := range lookupCandidates { + normalized := strings.ReplaceAll(candidate, "-4-5-", "-4.5-") + if pricing, ok := s.pricingData[normalized]; ok { + return pricing + } } // 3. 尝试模糊匹配(去掉版本号后缀) // claude-opus-4-5-20251101 -> claude-opus-4.5 - baseName := s.extractBaseName(modelLower) + baseName := s.extractBaseName(lookupCandidates[0]) for key, pricing := range s.pricingData { keyBase := s.extractBaseName(strings.ToLower(key)) if keyBase == baseName { @@ -422,18 +427,84 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing } // 4. 基于模型系列匹配(Claude) - if pricing := s.matchByModelFamily(modelLower); pricing != nil { + if pricing := s.matchByModelFamily(lookupCandidates[0]); pricing != nil { return pricing } // 5. OpenAI 模型回退策略 - if strings.HasPrefix(modelLower, "gpt-") { - return s.matchOpenAIModel(modelLower) + if strings.HasPrefix(lookupCandidates[0], "gpt-") { + return s.matchOpenAIModel(lookupCandidates[0]) } return nil } +func (s *PricingService) buildModelLookupCandidates(modelLower string) []string { + // Prefer canonical model name first (this also improves billing compatibility with "models/xxx"). + candidates := []string{ + normalizeModelNameForPricing(modelLower), + modelLower, + } + for _, cand := range []string{ + strings.TrimPrefix(modelLower, "models/"), + lastSegment(modelLower), + lastSegment(strings.TrimPrefix(modelLower, "models/")), + } { + candidates = append(candidates, cand) + } + + seen := make(map[string]struct{}, len(candidates)) + out := make([]string, 0, len(candidates)) + for _, c := range candidates { + c = strings.TrimSpace(c) + if c == "" { + continue + } + if _, ok := seen[c]; ok { + continue + } + seen[c] = struct{}{} + out = append(out, c) + } + if len(out) == 0 { + return []string{modelLower} + } + return out +} + +func normalizeModelNameForPricing(model string) string { + // Common Gemini/VertexAI forms: + // - models/gemini-2.0-flash-exp + // - publishers/google/models/gemini-1.5-pro + // - projects/.../locations/.../publishers/google/models/gemini-1.5-pro + model = strings.TrimSpace(model) + model = strings.TrimLeft(model, "/") + + if strings.HasPrefix(model, "models/") { + model = strings.TrimPrefix(model, "models/") + } + if strings.HasPrefix(model, "publishers/google/models/") { + model = strings.TrimPrefix(model, "publishers/google/models/") + } + + if idx := strings.LastIndex(model, "/publishers/google/models/"); idx != -1 { + model = model[idx+len("/publishers/google/models/"):] + } + if idx := strings.LastIndex(model, "/models/"); idx != -1 { + model = model[idx+len("/models/"):] + } + + model = strings.TrimLeft(model, "/") + return model +} + +func lastSegment(model string) string { + if idx := strings.LastIndex(model, "/"); idx != -1 { + return model[idx+1:] + } + return model +} + // extractBaseName 提取基础模型名称(去掉日期版本号) func (s *PricingService) extractBaseName(model string) string { // 移除日期后缀 (如 -20251101, -20241022) diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 24ef7b8e..187a517e 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -27,6 +27,7 @@ func NewTokenRefreshService( accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, + geminiOAuthService *GeminiOAuthService, cfg *config.Config, ) *TokenRefreshService { s := &TokenRefreshService{ @@ -39,6 +40,7 @@ func NewTokenRefreshService( s.refreshers = []TokenRefresher{ NewClaudeTokenRefresher(oauthService), NewOpenAITokenRefresher(openaiOAuthService), + NewGeminiTokenRefresher(geminiOAuthService), } return s From e36fb98fb9facfd09a39d0c16ae5a349b0d8f7eb Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 06:45:03 -0800 Subject: [PATCH 06/42] =?UTF-8?q?feat(handler):=20=E6=B7=BB=E5=8A=A0=20Gem?= =?UTF-8?q?ini=20OAuth=20Handler=20=E5=92=8C=E5=AE=8C=E5=96=84=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Gemini OAuth 授权处理器 - 扩展账号和网关处理器支持 Gemini - 注册 Gemini 相关路由 - 更新 Wire 依赖注入配置(所有层) - 更新 Docker Compose 配置 --- backend/cmd/server/wire.go | 4 + backend/cmd/server/wire_gen.go | 86 +++-- .../internal/handler/admin/account_handler.go | 16 + .../handler/admin/gemini_oauth_handler.go | 71 ++++ backend/internal/handler/gateway_handler.go | 29 +- backend/internal/handler/handler.go | 1 + backend/internal/handler/wire.go | 3 + backend/internal/repository/wire.go | 3 + backend/internal/server/router.go | 321 ++++++++++++++++-- backend/internal/service/wire.go | 6 +- deploy/docker-compose.yml | 2 +- 11 files changed, 488 insertions(+), 54 deletions(-) create mode 100644 backend/internal/handler/admin/gemini_oauth_handler.go diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 7d6ec065..a74c906a 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -99,6 +99,10 @@ func provideCleanup( openaiOAuth.Stop() return nil }}, + {"GeminiOAuthService", func() error { + services.GeminiOAuth.Stop() + return nil + }}, {"Redis", func() error { return rdb.Close() }}, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index e72e5f6e..6ed9897b 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -14,7 +14,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/server" - "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" "gorm.io/gorm" @@ -80,6 +79,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) + geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) + geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() + geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) rateLimitService := service.NewRateLimitService(accountRepository, configConfig) claudeUsageFetcher := repository.NewClaudeUsageFetcher() accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher) @@ -87,10 +89,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream) concurrencyCache := repository.NewConcurrencyCache(client) concurrencyService := service.NewConcurrencyService(concurrencyCache) - crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService) - accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) + crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) + geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) proxyHandler := admin.NewProxyHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService) settingHandler := admin.NewSettingHandler(settingService, emailService) @@ -101,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { systemHandler := handler.ProvideSystemHandler(updateService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) gatewayCache := repository.NewGatewayCache(client) pricingRemoteClient := repository.NewPricingRemoteClient() pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) @@ -112,18 +115,63 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityCache := repository.NewIdentityCache(client) identityService := service.NewIdentityService(identityCache) gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream) - gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) + geminiTokenCache := repository.NewGeminiTokenCache(client) + geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) + geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) - jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) - adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) - apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService) - engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware) + groupService := service.NewGroupService(groupRepository) + accountService := service.NewAccountService(accountRepository, groupRepository) + proxyService := service.NewProxyService(proxyRepository) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) + services := &service.Services{ + Auth: authService, + User: userService, + ApiKey: apiKeyService, + Group: groupService, + Account: accountService, + Proxy: proxyService, + Redeem: redeemService, + Usage: usageService, + Pricing: pricingService, + Billing: billingService, + BillingCache: billingCacheService, + Admin: adminService, + Gateway: gatewayService, + OpenAIGateway: openAIGatewayService, + OAuth: oAuthService, + OpenAIOAuth: openAIOAuthService, + GeminiOAuth: geminiOAuthService, + RateLimit: rateLimitService, + AccountUsage: accountUsageService, + AccountTest: accountTestService, + Setting: settingService, + Email: emailService, + EmailQueue: emailQueueService, + Turnstile: turnstileService, + Subscription: subscriptionService, + Concurrency: concurrencyService, + Identity: identityService, + Update: updateService, + TokenRefresh: tokenRefreshService, + } + repositories := &repository.Repositories{ + User: userRepository, + ApiKey: apiKeyRepository, + Group: groupRepository, + Account: accountRepository, + Proxy: proxyRepository, + RedeemCode: redeemCodeRepository, + UsageLog: usageLogRepository, + Setting: settingRepository, + UserSubscription: userSubscriptionRepository, + } + engine := server.ProvideRouter(configConfig, handlers, services, repositories) httpServer := server.ProvideHTTPServer(configConfig, engine) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig) - v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService) + v := provideCleanup(db, client, services) application := &Application{ Server: httpServer, Cleanup: v, @@ -148,11 +196,7 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { func provideCleanup( db *gorm.DB, rdb *redis.Client, - tokenRefresh *service.TokenRefreshService, - pricing *service.PricingService, - emailQueue *service.EmailQueueService, - oauth *service.OAuthService, - openaiOAuth *service.OpenAIOAuthService, + services *service.Services, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -163,23 +207,23 @@ func provideCleanup( fn func() error }{ {"TokenRefreshService", func() error { - tokenRefresh.Stop() + services.TokenRefresh.Stop() return nil }}, {"PricingService", func() error { - pricing.Stop() + services.Pricing.Stop() return nil }}, {"EmailQueueService", func() error { - emailQueue.Stop() + services.EmailQueue.Stop() return nil }}, {"OAuthService", func() error { - oauth.Stop() + services.OAuth.Stop() return nil }}, {"OpenAIOAuthService", func() error { - openaiOAuth.Stop() + services.OpenAIOAuth.Stop() return nil }}, {"Redis", func() error { diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 8ecb4326..3b6827e9 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -30,6 +30,7 @@ type AccountHandler struct { adminService service.AdminService oauthService *service.OAuthService openaiOAuthService *service.OpenAIOAuthService + geminiOAuthService *service.GeminiOAuthService rateLimitService *service.RateLimitService accountUsageService *service.AccountUsageService accountTestService *service.AccountTestService @@ -42,6 +43,7 @@ func NewAccountHandler( adminService service.AdminService, oauthService *service.OAuthService, openaiOAuthService *service.OpenAIOAuthService, + geminiOAuthService *service.GeminiOAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService, @@ -52,6 +54,7 @@ func NewAccountHandler( adminService: adminService, oauthService: oauthService, openaiOAuthService: openaiOAuthService, + geminiOAuthService: geminiOAuthService, rateLimitService: rateLimitService, accountUsageService: accountUsageService, accountTestService: accountTestService, @@ -345,6 +348,19 @@ func (h *AccountHandler) Refresh(c *gin.Context) { newCredentials[k] = v } } + } else if account.Platform == model.PlatformGemini { + tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account) + if err != nil { + response.InternalError(c, "Failed to refresh credentials: "+err.Error()) + return + } + + newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } } else { // Use Anthropic/Claude OAuth service to refresh token tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) diff --git a/backend/internal/handler/admin/gemini_oauth_handler.go b/backend/internal/handler/admin/gemini_oauth_handler.go new file mode 100644 index 00000000..4d39700b --- /dev/null +++ b/backend/internal/handler/admin/gemini_oauth_handler.go @@ -0,0 +1,71 @@ +package admin + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +type GeminiOAuthHandler struct { + geminiOAuthService *service.GeminiOAuthService +} + +func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *GeminiOAuthHandler { + return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService} +} + +type GeminiGenerateAuthURLRequest struct { + ProxyID *int64 `json:"proxy_id"` + RedirectURI string `json:"redirect_uri" binding:"required"` +} + +// GenerateAuthURL generates Google OAuth authorization URL for Gemini. +// POST /api/v1/admin/gemini/oauth/auth-url +func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) { + var req GeminiGenerateAuthURLRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) + if err != nil { + response.InternalError(c, "Failed to generate auth URL: "+err.Error()) + return + } + + response.Success(c, result) +} + +type GeminiExchangeCodeRequest struct { + SessionID string `json:"session_id" binding:"required"` + State string `json:"state" binding:"required"` + Code string `json:"code" binding:"required"` + RedirectURI string `json:"redirect_uri" binding:"required"` + ProxyID *int64 `json:"proxy_id"` +} + +// ExchangeCode exchanges authorization code for tokens. +// POST /api/v1/admin/gemini/oauth/exchange-code +func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) { + var req GeminiExchangeCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{ + SessionID: req.SessionID, + State: req.State, + Code: req.Code, + RedirectURI: req.RedirectURI, + ProxyID: req.ProxyID, + }) + if err != nil { + response.BadRequest(c, "Failed to exchange code: "+err.Error()) + return + } + + response.Success(c, tokenInfo) +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index bfb8b6fd..a980dfb1 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -22,15 +22,23 @@ import ( // GatewayHandler handles API gateway requests type GatewayHandler struct { gatewayService *service.GatewayService + geminiCompatService *service.GeminiMessagesCompatService userService *service.UserService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper } // NewGatewayHandler creates a new GatewayHandler -func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler { +func NewGatewayHandler( + gatewayService *service.GatewayService, + geminiCompatService *service.GeminiMessagesCompatService, + userService *service.UserService, + concurrencyService *service.ConcurrencyService, + billingCacheService *service.BillingCacheService, +) *GatewayHandler { return &GatewayHandler{ gatewayService: gatewayService, + geminiCompatService: geminiCompatService, userService: userService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), @@ -115,8 +123,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 计算粘性会话hash sessionHash := h.gatewayService.GenerateSessionHash(body) + platform := "" + if apiKey.Group != nil { + platform = apiKey.Group.Platform + } + // 选择支持该模型的账号 - account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) + var account *model.Account + if platform == model.PlatformGemini { + account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) + } else { + account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) + } if err != nil { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return @@ -144,7 +162,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 转发请求 - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + var result *service.ForwardResult + if platform == model.PlatformGemini { + result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + } else { + result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) + } if err != nil { // 错误响应已在Forward中处理,这里只记录日志 log.Printf("Forward request failed: %v", err) diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 11c9dcf1..af28bc1f 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -12,6 +12,7 @@ type AdminHandlers struct { Account *admin.AccountHandler OAuth *admin.OAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler + GeminiOAuth *admin.GeminiOAuthHandler Proxy *admin.ProxyHandler Redeem *admin.RedeemHandler Setting *admin.SettingHandler diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index a37cb3e6..f6e2c031 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -15,6 +15,7 @@ func ProvideAdminHandlers( accountHandler *admin.AccountHandler, oauthHandler *admin.OAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler, + geminiOAuthHandler *admin.GeminiOAuthHandler, proxyHandler *admin.ProxyHandler, redeemHandler *admin.RedeemHandler, settingHandler *admin.SettingHandler, @@ -29,6 +30,7 @@ func ProvideAdminHandlers( Account: accountHandler, OAuth: oauthHandler, OpenAIOAuth: openaiOAuthHandler, + GeminiOAuth: geminiOAuthHandler, Proxy: proxyHandler, Redeem: redeemHandler, Setting: settingHandler, @@ -95,6 +97,7 @@ var ProviderSet = wire.NewSet( admin.NewAccountHandler, admin.NewOAuthHandler, admin.NewOpenAIOAuthHandler, + admin.NewGeminiOAuthHandler, admin.NewProxyHandler, admin.NewRedeemHandler, admin.NewSettingHandler, diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index ceeb82fc..53d42d90 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -25,6 +25,7 @@ var ProviderSet = wire.NewSet( NewIdentityCache, NewRedeemCache, NewUpdateCache, + NewGeminiTokenCache, // HTTP service ports (DI Strategy A: return interface directly) NewTurnstileVerifier, @@ -35,4 +36,6 @@ var ProviderSet = wire.NewSet( NewClaudeOAuthClient, NewHTTPUpstream, NewOpenAIOAuthClient, + NewGeminiOAuthClient, + NewGeminiCliCodeAssistClient, ) diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 226fe99b..d1a73c43 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -1,54 +1,319 @@ package server import ( + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/server/routes" + "github.com/Wei-Shaw/sub2api/internal/middleware" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/web" + "net/http" "github.com/gin-gonic/gin" ) // SetupRouter 配置路由器中间件和路由 -func SetupRouter( - r *gin.Engine, - handlers *handler.Handlers, - jwtAuth middleware2.JWTAuthMiddleware, - adminAuth middleware2.AdminAuthMiddleware, - apiKeyAuth middleware2.ApiKeyAuthMiddleware, -) *gin.Engine { +func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine { // 应用中间件 - r.Use(middleware2.Logger()) - r.Use(middleware2.CORS()) + r.Use(middleware.Logger()) + r.Use(middleware.CORS()) + + // 注册路由 + registerRoutes(r, handlers, services, repos) // Serve embedded frontend if available if web.HasEmbeddedFrontend() { r.Use(web.ServeEmbeddedFrontend()) } - // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth) - return r } // registerRoutes 注册所有 HTTP 路由 -func registerRoutes( - r *gin.Engine, - h *handler.Handlers, - jwtAuth middleware2.JWTAuthMiddleware, - adminAuth middleware2.AdminAuthMiddleware, - apiKeyAuth middleware2.ApiKeyAuthMiddleware, -) { - // 通用路由(健康检查、状态等) - routes.RegisterCommonRoutes(r) +func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) { + // 健康检查 + r.GET("/health", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + + // Claude Code 遥测日志(忽略,直接返回200) + r.POST("/api/event_logging/batch", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + // Setup status endpoint (always returns needs_setup: false in normal mode) + // This is used by the frontend to detect when the service has restarted after setup + r.GET("/setup/status", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": gin.H{ + "needs_setup": false, + "step": "completed", + }, + }) + }) // API v1 v1 := r.Group("/api/v1") + { + // 公开接口 + auth := v1.Group("/auth") + { + auth.POST("/register", h.Auth.Register) + auth.POST("/login", h.Auth.Login) + auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + } - // 注册各模块路由 - routes.RegisterAuthRoutes(v1, h, jwtAuth) - routes.RegisterUserRoutes(v1, h, jwtAuth) - routes.RegisterAdminRoutes(v1, h, adminAuth) - routes.RegisterGatewayRoutes(r, h, apiKeyAuth) + // 公开设置(无需认证) + settings := v1.Group("/settings") + { + settings.GET("/public", h.Setting.GetPublicSettings) + } + + // 需要认证的接口 + authenticated := v1.Group("") + authenticated.Use(middleware.JWTAuth(s.Auth, repos.User)) + { + // 当前用户信息 + authenticated.GET("/auth/me", h.Auth.GetCurrentUser) + + // 用户接口 + user := authenticated.Group("/user") + { + user.GET("/profile", h.User.GetProfile) + user.PUT("/password", h.User.ChangePassword) + user.PUT("", h.User.UpdateProfile) + } + + // API Key管理 + keys := authenticated.Group("/keys") + { + keys.GET("", h.APIKey.List) + keys.GET("/:id", h.APIKey.GetByID) + keys.POST("", h.APIKey.Create) + keys.PUT("/:id", h.APIKey.Update) + keys.DELETE("/:id", h.APIKey.Delete) + } + + // 用户可用分组(非管理员接口) + groups := authenticated.Group("/groups") + { + groups.GET("/available", h.APIKey.GetAvailableGroups) + } + + // 使用记录 + usage := authenticated.Group("/usage") + { + usage.GET("", h.Usage.List) + usage.GET("/:id", h.Usage.GetByID) + usage.GET("/stats", h.Usage.Stats) + // User dashboard endpoints + usage.GET("/dashboard/stats", h.Usage.DashboardStats) + usage.GET("/dashboard/trend", h.Usage.DashboardTrend) + usage.GET("/dashboard/models", h.Usage.DashboardModels) + usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage) + } + + // 卡密兑换 + redeem := authenticated.Group("/redeem") + { + redeem.POST("", h.Redeem.Redeem) + redeem.GET("/history", h.Redeem.GetHistory) + } + + // 用户订阅 + subscriptions := authenticated.Group("/subscriptions") + { + subscriptions.GET("", h.Subscription.List) + subscriptions.GET("/active", h.Subscription.GetActive) + subscriptions.GET("/progress", h.Subscription.GetProgress) + subscriptions.GET("/summary", h.Subscription.GetSummary) + } + } + + // 管理员接口 + admin := v1.Group("/admin") + admin.Use(middleware.AdminAuth(s.Auth, repos.User, s.Setting)) + { + // 仪表盘 + dashboard := admin.Group("/dashboard") + { + dashboard.GET("/stats", h.Admin.Dashboard.GetStats) + dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) + dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) + dashboard.GET("/models", h.Admin.Dashboard.GetModelStats) + dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend) + dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) + dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) + dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage) + } + + // 用户管理 + users := admin.Group("/users") + { + users.GET("", h.Admin.User.List) + users.GET("/:id", h.Admin.User.GetByID) + users.POST("", h.Admin.User.Create) + users.PUT("/:id", h.Admin.User.Update) + users.DELETE("/:id", h.Admin.User.Delete) + users.POST("/:id/balance", h.Admin.User.UpdateBalance) + users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys) + users.GET("/:id/usage", h.Admin.User.GetUserUsage) + } + + // 分组管理 + groups := admin.Group("/groups") + { + groups.GET("", h.Admin.Group.List) + groups.GET("/all", h.Admin.Group.GetAll) + groups.GET("/:id", h.Admin.Group.GetByID) + groups.POST("", h.Admin.Group.Create) + groups.PUT("/:id", h.Admin.Group.Update) + groups.DELETE("/:id", h.Admin.Group.Delete) + groups.GET("/:id/stats", h.Admin.Group.GetStats) + groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) + } + + // 账号管理 + accounts := admin.Group("/accounts") + { + accounts.GET("", h.Admin.Account.List) + accounts.GET("/:id", h.Admin.Account.GetByID) + accounts.POST("", h.Admin.Account.Create) + accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS) + accounts.PUT("/:id", h.Admin.Account.Update) + accounts.DELETE("/:id", h.Admin.Account.Delete) + accounts.POST("/:id/test", h.Admin.Account.Test) + accounts.POST("/:id/refresh", h.Admin.Account.Refresh) + accounts.GET("/:id/stats", h.Admin.Account.GetStats) + accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) + accounts.GET("/:id/usage", h.Admin.Account.GetUsage) + accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) + accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) + accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) + accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) + accounts.POST("/batch", h.Admin.Account.BatchCreate) + accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials) + accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate) + + // Claude OAuth routes + accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL) + accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL) + accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode) + accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode) + accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth) + accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth) + } + + // OpenAI OAuth routes + openai := admin.Group("/openai") + { + openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL) + openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode) + openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken) + openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken) + openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth) + } + + // Gemini OAuth routes + gemini := admin.Group("/gemini") + { + gemini.POST("/oauth/auth-url", h.Admin.GeminiOAuth.GenerateAuthURL) + gemini.POST("/oauth/exchange-code", h.Admin.GeminiOAuth.ExchangeCode) + } + + // 代理管理 + proxies := admin.Group("/proxies") + { + proxies.GET("", h.Admin.Proxy.List) + proxies.GET("/all", h.Admin.Proxy.GetAll) + proxies.GET("/:id", h.Admin.Proxy.GetByID) + proxies.POST("", h.Admin.Proxy.Create) + proxies.PUT("/:id", h.Admin.Proxy.Update) + proxies.DELETE("/:id", h.Admin.Proxy.Delete) + proxies.POST("/:id/test", h.Admin.Proxy.Test) + proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) + proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts) + proxies.POST("/batch", h.Admin.Proxy.BatchCreate) + } + + // 卡密管理 + codes := admin.Group("/redeem-codes") + { + codes.GET("", h.Admin.Redeem.List) + codes.GET("/stats", h.Admin.Redeem.GetStats) + codes.GET("/export", h.Admin.Redeem.Export) + codes.GET("/:id", h.Admin.Redeem.GetByID) + codes.POST("/generate", h.Admin.Redeem.Generate) + codes.DELETE("/:id", h.Admin.Redeem.Delete) + codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete) + codes.POST("/:id/expire", h.Admin.Redeem.Expire) + } + + // 系统设置 + adminSettings := admin.Group("/settings") + { + adminSettings.GET("", h.Admin.Setting.GetSettings) + adminSettings.PUT("", h.Admin.Setting.UpdateSettings) + adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection) + adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail) + // Admin API Key 管理 + adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey) + adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey) + adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey) + } + + // 系统管理 + system := admin.Group("/system") + { + system.GET("/version", h.Admin.System.GetVersion) + system.GET("/check-updates", h.Admin.System.CheckUpdates) + system.POST("/update", h.Admin.System.PerformUpdate) + system.POST("/rollback", h.Admin.System.Rollback) + system.POST("/restart", h.Admin.System.RestartService) + } + + // 订阅管理 + subscriptions := admin.Group("/subscriptions") + { + subscriptions.GET("", h.Admin.Subscription.List) + subscriptions.GET("/:id", h.Admin.Subscription.GetByID) + subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress) + subscriptions.POST("/assign", h.Admin.Subscription.Assign) + subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign) + subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend) + subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke) + } + + // 分组下的订阅列表 + admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup) + + // 用户下的订阅列表 + admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser) + + // 使用记录管理 + usage := admin.Group("/usage") + { + usage.GET("", h.Admin.Usage.List) + usage.GET("/stats", h.Admin.Usage.Stats) + usage.GET("/search-users", h.Admin.Usage.SearchUsers) + usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys) + } + } + } + + // API网关(Claude API兼容) + gateway := r.Group("/v1") + gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription)) + { + gateway.POST("/messages", h.Gateway.Messages) + gateway.POST("/messages/count_tokens", h.Gateway.CountTokens) + gateway.GET("/models", h.Gateway.Models) + gateway.GET("/usage", h.Gateway.Usage) + // OpenAI Responses API + gateway.POST("/responses", h.OpenAIGateway.Responses) + } + + // OpenAI Responses API(不带v1前缀的别名) + r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses) } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 02ef2392..d60ec737 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -36,9 +36,10 @@ func ProvideTokenRefreshService( accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, + geminiOAuthService *GeminiOAuthService, cfg *config.Config, ) *TokenRefreshService { - svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, cfg) + svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, cfg) svc.Start() return svc } @@ -63,6 +64,9 @@ var ProviderSet = wire.NewSet( NewOpenAIGatewayService, NewOAuthService, NewOpenAIOAuthService, + NewGeminiOAuthService, + NewGeminiTokenProvider, + NewGeminiMessagesCompatService, NewRateLimitService, NewAccountUsageService, NewAccountTestService, diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index d534d3d6..59a91969 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -16,7 +16,7 @@ services: # Sub2API Application # =========================================================================== sub2api: - image: weishaw/sub2api:latest + image: sub2api:latest container_name: sub2api restart: unless-stopped ports: From 03a8ae62e53a9796299f347470e7e628f833a939 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 08:39:32 -0800 Subject: [PATCH 07/42] =?UTF-8?q?feat(backend):=20=E5=AE=8C=E5=96=84=20Gem?= =?UTF-8?q?ini=20OAuth=20Token=20=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 account_handler 中 token 字段类型转换(int64 转 string) - 增强 Account.GetCredential 支持多种数值类型(float64, int, json.Number 等) - 添加 Account.IsGemini() 方法用于平台判断 - 优化 refresh_token 和 scope 的空值处理 --- .../internal/handler/admin/account_handler.go | 27 +++++++------ backend/internal/model/account.go | 39 ++++++++++++++++++- .../internal/service/gemini_oauth_service.go | 6 ++- 3 files changed, 57 insertions(+), 15 deletions(-) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 3b6827e9..8351d1af 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" @@ -375,18 +376,22 @@ func (h *AccountHandler) Refresh(c *gin.Context) { newCredentials[k] = v } - // Update token-related fields - newCredentials["access_token"] = tokenInfo.AccessToken - newCredentials["token_type"] = tokenInfo.TokenType - newCredentials["expires_in"] = tokenInfo.ExpiresIn - newCredentials["expires_at"] = tokenInfo.ExpiresAt - newCredentials["refresh_token"] = tokenInfo.RefreshToken - newCredentials["scope"] = tokenInfo.Scope - } + // Update token-related fields + newCredentials["access_token"] = tokenInfo.AccessToken + newCredentials["token_type"] = tokenInfo.TokenType + newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) + newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) + if strings.TrimSpace(tokenInfo.RefreshToken) != "" { + newCredentials["refresh_token"] = tokenInfo.RefreshToken + } + if strings.TrimSpace(tokenInfo.Scope) != "" { + newCredentials["scope"] = tokenInfo.Scope + } + } - updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ - Credentials: newCredentials, - }) + updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ + Credentials: newCredentials, + }) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/model/account.go b/backend/internal/model/account.go index 9b09b114..8abd9049 100644 --- a/backend/internal/model/account.go +++ b/backend/internal/model/account.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/json" "errors" + "strconv" "time" "gorm.io/gorm" @@ -128,8 +129,37 @@ func (a *Account) GetCredential(key string) string { return "" } if v, ok := a.Credentials[key]; ok { - if s, ok := v.(string); ok { - return s + switch vv := v.(type) { + case string: + return vv + case json.Number: + return vv.String() + case float64: + // JSON numbers decode to float64; keep integer formatting for integer-like values. + i := int64(vv) + if vv == float64(i) { + return strconv.FormatInt(i, 10) + } + return strconv.FormatFloat(vv, 'f', -1, 64) + case float32: + f := float64(vv) + i := int64(f) + if f == float64(i) { + return strconv.FormatInt(i, 10) + } + return strconv.FormatFloat(f, 'f', -1, 64) + case int: + return strconv.FormatInt(int64(vv), 10) + case int64: + return strconv.FormatInt(vv, 10) + case int32: + return strconv.FormatInt(int64(vv), 10) + case uint: + return strconv.FormatUint(uint64(vv), 10) + case uint64: + return strconv.FormatUint(vv, 10) + case uint32: + return strconv.FormatUint(uint64(vv), 10) } } return "" @@ -291,6 +321,11 @@ func (a *Account) IsAnthropic() bool { return a.Platform == PlatformAnthropic } +// IsGemini 检查是否为 Gemini 平台账号 +func (a *Account) IsGemini() bool { + return a.Platform == PlatformGemini +} + // IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号 func (a *Account) IsOpenAIOAuth() bool { return a.IsOpenAI() && a.Type == AccountTypeOAuth diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 067a2455..b3dc3f09 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -139,7 +139,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch } s.sessionStore.Delete(input.SessionID) - expiresAt := time.Now().Unix() + tokenResp.ExpiresIn + // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 projectID, _ := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) return &GeminiTokenInfo{ @@ -167,7 +168,8 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) if err == nil { - expiresAt := time.Now().Unix() + tokenResp.ExpiresIn + // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, From 0b30cc2b7e049a2079fadf5d04e02a354f806066 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 08:39:48 -0800 Subject: [PATCH 08/42] =?UTF-8?q?feat(frontend):=20=E6=96=B0=E5=A2=9E=20Ge?= =?UTF-8?q?mini=20OAuth=20=E6=8E=88=E6=9D=83=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 /admin/gemini API 接口封装(generateAuthUrl, exchangeCode) - 新增 useGeminiOAuth composable 处理 Gemini OAuth 流程 - 新增 OAuthCallbackView 视图用于接收 OAuth 回调 - 支持 code/state 参数提取和 credentials 构建 --- frontend/src/api/admin/gemini.ts | 47 +++++++ frontend/src/composables/useGeminiOAuth.ts | 133 ++++++++++++++++++ frontend/src/views/auth/OAuthCallbackView.vue | 87 ++++++++++++ 3 files changed, 267 insertions(+) create mode 100644 frontend/src/api/admin/gemini.ts create mode 100644 frontend/src/composables/useGeminiOAuth.ts create mode 100644 frontend/src/views/auth/OAuthCallbackView.vue diff --git a/frontend/src/api/admin/gemini.ts b/frontend/src/api/admin/gemini.ts new file mode 100644 index 00000000..0a2c4a66 --- /dev/null +++ b/frontend/src/api/admin/gemini.ts @@ -0,0 +1,47 @@ +/** + * Admin Gemini API endpoints + * Handles Gemini OAuth flows for administrators + */ + +import { apiClient } from '../client' + +export interface GeminiAuthUrlResponse { + auth_url: string + session_id: string + state: string +} + +export interface GeminiAuthUrlRequest { + redirect_uri: string + proxy_id?: number +} + +export interface GeminiExchangeCodeRequest { + session_id: string + state: string + code: string + redirect_uri: string + proxy_id?: number +} + +export type GeminiTokenInfo = Record + +export async function generateAuthUrl( + payload: GeminiAuthUrlRequest +): Promise { + const { data } = await apiClient.post( + '/admin/gemini/oauth/auth-url', + payload + ) + return data +} + +export async function exchangeCode(payload: GeminiExchangeCodeRequest): Promise { + const { data } = await apiClient.post( + '/admin/gemini/oauth/exchange-code', + payload + ) + return data +} + +export default { generateAuthUrl, exchangeCode } diff --git a/frontend/src/composables/useGeminiOAuth.ts b/frontend/src/composables/useGeminiOAuth.ts new file mode 100644 index 00000000..63e5c2b9 --- /dev/null +++ b/frontend/src/composables/useGeminiOAuth.ts @@ -0,0 +1,133 @@ +import { ref } from 'vue' +import { useI18n } from 'vue-i18n' +import { useAppStore } from '@/stores/app' +import { adminAPI } from '@/api/admin' + +export interface GeminiTokenInfo { + access_token?: string + refresh_token?: string + token_type?: string + scope?: string + expires_at?: number | string + project_id?: string + [key: string]: unknown +} + +export function useGeminiOAuth() { + const appStore = useAppStore() + const { t } = useI18n() + + const authUrl = ref('') + const sessionId = ref('') + const state = ref('') + const loading = ref(false) + const error = ref('') + + const resetState = () => { + authUrl.value = '' + sessionId.value = '' + state.value = '' + loading.value = false + error.value = '' + } + + const generateAuthUrl = async ( + proxyId: number | null | undefined, + redirectUri: string + ): Promise => { + loading.value = true + authUrl.value = '' + sessionId.value = '' + state.value = '' + error.value = '' + + try { + if (!redirectUri?.trim()) { + error.value = t('admin.accounts.oauth.gemini.missingRedirectUri') + appStore.showError(error.value) + return false + } + + const payload: Record = { redirect_uri: redirectUri.trim() } + if (proxyId) payload.proxy_id = proxyId + + const response = await adminAPI.gemini.generateAuthUrl(payload as any) + authUrl.value = response.auth_url + sessionId.value = response.session_id + state.value = response.state + return true + } catch (err: any) { + error.value = err.response?.data?.detail || t('admin.accounts.oauth.gemini.failedToGenerateUrl') + appStore.showError(error.value) + return false + } finally { + loading.value = false + } + } + + const exchangeAuthCode = async (params: { + code: string + sessionId: string + state: string + redirectUri: string + proxyId?: number | null + }): Promise => { + const code = params.code?.trim() + if (!code || !params.sessionId || !params.state || !params.redirectUri?.trim()) { + error.value = t('admin.accounts.oauth.gemini.missingExchangeParams') + return null + } + + loading.value = true + error.value = '' + + try { + const payload: Record = { + session_id: params.sessionId, + state: params.state, + code, + redirect_uri: params.redirectUri.trim() + } + if (params.proxyId) payload.proxy_id = params.proxyId + + const tokenInfo = await adminAPI.gemini.exchangeCode(payload as any) + return tokenInfo as GeminiTokenInfo + } catch (err: any) { + error.value = err.response?.data?.detail || t('admin.accounts.oauth.gemini.failedToExchangeCode') + appStore.showError(error.value) + return null + } finally { + loading.value = false + } + } + + const buildCredentials = (tokenInfo: GeminiTokenInfo): Record => { + let expiresAt: string | undefined + if (typeof tokenInfo.expires_at === 'number' && Number.isFinite(tokenInfo.expires_at)) { + expiresAt = Math.floor(tokenInfo.expires_at).toString() + } else if (typeof tokenInfo.expires_at === 'string' && tokenInfo.expires_at.trim()) { + expiresAt = tokenInfo.expires_at.trim() + } + + return { + access_token: tokenInfo.access_token, + refresh_token: tokenInfo.refresh_token, + token_type: tokenInfo.token_type, + expires_at: expiresAt, + scope: tokenInfo.scope, + project_id: tokenInfo.project_id + } + } + + return { + authUrl, + sessionId, + state, + loading, + error, + resetState, + generateAuthUrl, + exchangeAuthCode, + buildCredentials + } +} diff --git a/frontend/src/views/auth/OAuthCallbackView.vue b/frontend/src/views/auth/OAuthCallbackView.vue new file mode 100644 index 00000000..64489e42 --- /dev/null +++ b/frontend/src/views/auth/OAuthCallbackView.vue @@ -0,0 +1,87 @@ + + + From 1ac8b1f03e17b64358eb371583fd60f9b4aef496 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 25 Dec 2025 08:40:05 -0800 Subject: [PATCH 09/42] =?UTF-8?q?feat(frontend):=20Components=20=E9=9B=86?= =?UTF-8?q?=E6=88=90=20Gemini=20=E8=B4=A6=E5=8F=B7=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - CreateAccountModal: 添加 Gemini 平台选项和 OAuth 授权流程 - EditAccountModal: 支持 Gemini 账号编辑 - OAuthAuthorizationFlow: 新增 Gemini 平台 OAuth 流程处理(支持 state 参数) - ReAuthAccountModal: 支持 Gemini 账号重新授权 - 优化代码格式和组件逻辑 --- .../components/account/CreateAccountModal.vue | 1070 +++++++++++++---- .../components/account/EditAccountModal.vue | 570 ++++++--- .../account/OAuthAuthorizationFlow.vue | 355 ++++-- .../components/account/ReAuthAccountModal.vue | 298 ++++- 4 files changed, 1783 insertions(+), 510 deletions(-) diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 1634777e..e22af4c7 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1,10 +1,5 @@