Files
sub2api/backend/internal/service/openai_oauth_service.go

380 lines
13 KiB
Go
Raw Normal View History

2025-12-22 22:58:31 +08:00
package service
import (
"context"
"crypto/subtle"
"log/slog"
"net/http"
"strings"
2025-12-22 22:58:31 +08:00
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
2025-12-24 21:07:21 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
2025-12-22 22:58:31 +08:00
)
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
proxyRepo ProxyRepository
oauthClient OpenAIOAuthClient
privacyClientFactory PrivacyClientFactory // 用于调用 chatgpt.com/backend-apiImpersonateChrome
2025-12-22 22:58:31 +08:00
}
// NewOpenAIOAuthService creates a new OpenAI OAuth service
2025-12-25 17:15:01 +08:00
func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthClient) *OpenAIOAuthService {
2025-12-22 22:58:31 +08:00
return &OpenAIOAuthService{
sessionStore: openai.NewSessionStore(),
proxyRepo: proxyRepo,
oauthClient: oauthClient,
}
}
// SetPrivacyClientFactory injects the ImpersonateChrome client factory that is
// used to call chatgpt.com/backend-api for account information (plan_type and
// so on). The factory replaces any previously stored one.
func (s *OpenAIOAuthService) SetPrivacyClientFactory(factory PrivacyClientFactory) {
	s.privacyClientFactory = factory
}
2025-12-22 22:58:31 +08:00
// OpenAIAuthURLResult is the outcome of starting an OAuth flow: the URL to
// send the user to and the ID of the session that was stored for it.
type OpenAIAuthURLResult struct {
	// AuthURL is the authorization URL to open.
	AuthURL string `json:"auth_url"`
	// SessionID identifies the stored session and must be supplied again
	// when exchanging the authorization code.
	SessionID string `json:"session_id"`
}
// GenerateAuthURL generates an OpenAI OAuth authorization URL
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) {
2025-12-22 22:58:31 +08:00
// Generate PKCE values
state, err := openai.GenerateState()
if err != nil {
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
2025-12-22 22:58:31 +08:00
}
codeVerifier, err := openai.GenerateCodeVerifier()
if err != nil {
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
2025-12-22 22:58:31 +08:00
}
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
// Generate session ID
sessionID, err := openai.GenerateSessionID()
if err != nil {
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
2025-12-22 22:58:31 +08:00
}
// Get proxy URL if specified
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
2025-12-22 22:58:31 +08:00
proxyURL = proxy.URL()
}
}
// Use default redirect URI if not specified
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
normalizedPlatform := normalizeOpenAIOAuthPlatform(platform)
clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform)
2025-12-22 22:58:31 +08:00
// Store session
session := &openai.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
ClientID: clientID,
2025-12-22 22:58:31 +08:00
RedirectURI: redirectURI,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
// Build authorization URL
authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform)
2025-12-22 22:58:31 +08:00
return &OpenAIAuthURLResult{
AuthURL: authURL,
SessionID: sessionID,
}, nil
}
// OpenAIExchangeCodeInput represents the input for the authorization-code
// exchange performed by ExchangeCode.
type OpenAIExchangeCodeInput struct {
	// SessionID identifies the session stored when the auth URL was generated.
	SessionID string
	// Code is the authorization code to exchange.
	Code string
	// State must be non-empty and equal to the state stored in the session.
	State string
	// RedirectURI, when non-empty, overrides the redirect URI stored in the
	// session.
	RedirectURI string
	// ProxyID, when non-nil, selects the proxy for the exchange instead of the
	// proxy URL stored in the session.
	ProxyID *int64
}
// OpenAITokenInfo represents the token information for OpenAI, together with
// the account details derived from it.
type OpenAITokenInfo struct {
	// AccessToken, RefreshToken and IDToken are copied from the token response.
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	IDToken      string `json:"id_token,omitempty"`
	// ExpiresIn is the token lifetime in seconds.
	ExpiresIn int64 `json:"expires_in"`
	// ExpiresAt is the expiry as a Unix timestamp in seconds.
	ExpiresAt int64 `json:"expires_at"`
	// ClientID is the OAuth client ID associated with the token, if known.
	ClientID string `json:"client_id,omitempty"`
	// Email, ChatGPTAccountID, ChatGPTUserID, OrganizationID and PlanType are
	// filled from the parsed ID token when it is available.
	Email            string `json:"email,omitempty"`
	ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
	ChatGPTUserID    string `json:"chatgpt_user_id,omitempty"`
	OrganizationID   string `json:"organization_id,omitempty"`
	PlanType         string `json:"plan_type,omitempty"`
	// SubscriptionExpiresAt and PrivacyMode are set by the best-effort
	// enrichment step that queries the ChatGPT backend API.
	SubscriptionExpiresAt string `json:"subscription_expires_at,omitempty"`
	PrivacyMode           string `json:"privacy_mode,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExchangeCodeInput) (*OpenAITokenInfo, error) {
// Get session
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
2025-12-22 22:58:31 +08:00
}
if input.State == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
}
if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
}
2025-12-22 22:58:31 +08:00
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
2025-12-22 22:58:31 +08:00
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
}
if proxy != nil {
2025-12-22 22:58:31 +08:00
proxyURL = proxy.URL()
}
}
// Use redirect URI from session or input
redirectURI := session.RedirectURI
if input.RedirectURI != "" {
redirectURI = input.RedirectURI
}
clientID := strings.TrimSpace(session.ClientID)
if clientID == "" {
clientID = openai.ClientID
}
2025-12-22 22:58:31 +08:00
// Exchange code for token
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID)
2025-12-22 22:58:31 +08:00
if err != nil {
return nil, err
2025-12-22 22:58:31 +08:00
}
// Parse ID token to get user info
var userInfo *openai.UserInfo
if tokenResp.IDToken != "" {
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
if parseErr != nil {
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
} else {
2025-12-22 22:58:31 +08:00
userInfo = claims.GetUserInfo()
}
}
// Delete session after successful exchange
s.sessionStore.Delete(input.SessionID)
tokenInfo := &OpenAITokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
IDToken: tokenResp.IDToken,
ExpiresIn: int64(tokenResp.ExpiresIn),
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
ClientID: clientID,
2025-12-22 22:58:31 +08:00
}
if userInfo != nil {
tokenInfo.Email = userInfo.Email
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
tokenInfo.OrganizationID = userInfo.OrganizationID
tokenInfo.PlanType = userInfo.PlanType
2025-12-22 22:58:31 +08:00
}
s.enrichTokenInfo(ctx, tokenInfo, proxyURL)
2025-12-22 22:58:31 +08:00
return tokenInfo, nil
}
// RefreshToken refreshes an OpenAI OAuth token without naming a client ID.
// It is shorthand for RefreshTokenWithClientID with an empty client ID.
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
	// No explicit client ID is supplied on this path.
	const noClientID = ""
	return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, noClientID)
}
// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
2025-12-22 22:58:31 +08:00
if err != nil {
return nil, err
}
// Parse ID token to get user info
var userInfo *openai.UserInfo
if tokenResp.IDToken != "" {
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
if parseErr != nil {
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
} else {
2025-12-22 22:58:31 +08:00
userInfo = claims.GetUserInfo()
}
}
tokenInfo := &OpenAITokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
IDToken: tokenResp.IDToken,
ExpiresIn: int64(tokenResp.ExpiresIn),
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
}
if trimmed := strings.TrimSpace(clientID); trimmed != "" {
tokenInfo.ClientID = trimmed
}
2025-12-22 22:58:31 +08:00
if userInfo != nil {
tokenInfo.Email = userInfo.Email
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
tokenInfo.OrganizationID = userInfo.OrganizationID
tokenInfo.PlanType = userInfo.PlanType
2025-12-22 22:58:31 +08:00
}
s.enrichTokenInfo(ctx, tokenInfo, proxyURL)
return tokenInfo, nil
}
// enrichTokenInfo 通过 ChatGPT backend-api 补全 tokenInfo 并设置隐私best-effort
// 从 accounts/check 获取最新 plan_type、subscription_expires_at、email
// 然后尝试关闭训练数据共享。适用于所有获取/刷新 token 的路径。
func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *OpenAITokenInfo, proxyURL string) {
if tokenInfo.AccessToken == "" || s.privacyClientFactory == nil {
return
}
// 从 access_token JWT 中提取 orgIDpoid用于匹配正确的账号
orgID := tokenInfo.OrganizationID
if orgID == "" {
if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil {
orgID = atClaims.OpenAIAuth.POID
}
}
if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil {
if info.PlanType != "" {
tokenInfo.PlanType = info.PlanType
}
if info.SubscriptionExpiresAt != "" {
tokenInfo.SubscriptionExpiresAt = info.SubscriptionExpiresAt
}
if tokenInfo.Email == "" && info.Email != "" {
tokenInfo.Email = info.Email
}
}
// 尝试设置隐私关闭训练数据共享best-effort
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
2025-12-22 22:58:31 +08:00
}
// RefreshAccountToken refreshes token for an OpenAI OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if account.Platform != PlatformOpenAI {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
}
if account.Type != AccountTypeOAuth {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
2025-12-22 22:58:31 +08:00
}
refreshToken := account.GetCredential("refresh_token")
2025-12-22 22:58:31 +08:00
if refreshToken == "" {
accessToken := account.GetCredential("access_token")
if accessToken != "" {
tokenInfo := &OpenAITokenInfo{
AccessToken: accessToken,
RefreshToken: "",
IDToken: account.GetCredential("id_token"),
ClientID: account.GetCredential("client_id"),
Email: account.GetCredential("email"),
ChatGPTAccountID: account.GetCredential("chatgpt_account_id"),
ChatGPTUserID: account.GetCredential("chatgpt_user_id"),
OrganizationID: account.GetCredential("organization_id"),
PlanType: account.GetCredential("plan_type"),
}
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
tokenInfo.ExpiresAt = expiresAt.Unix()
tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds())
}
return tokenInfo, nil
}
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
2025-12-22 22:58:31 +08:00
}
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
clientID := account.GetCredential("client_id")
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
2025-12-22 22:58:31 +08:00
}
// BuildAccountCredentials builds credentials map from token info
func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) map[string]any {
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"expires_at": expiresAt,
}
// 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
2025-12-22 22:58:31 +08:00
}
if tokenInfo.IDToken != "" {
creds["id_token"] = tokenInfo.IDToken
}
if tokenInfo.Email != "" {
creds["email"] = tokenInfo.Email
}
if tokenInfo.ChatGPTAccountID != "" {
creds["chatgpt_account_id"] = tokenInfo.ChatGPTAccountID
}
if tokenInfo.ChatGPTUserID != "" {
creds["chatgpt_user_id"] = tokenInfo.ChatGPTUserID
}
if tokenInfo.OrganizationID != "" {
creds["organization_id"] = tokenInfo.OrganizationID
}
if tokenInfo.PlanType != "" {
creds["plan_type"] = tokenInfo.PlanType
}
if tokenInfo.SubscriptionExpiresAt != "" {
creds["subscription_expires_at"] = tokenInfo.SubscriptionExpiresAt
}
if strings.TrimSpace(tokenInfo.ClientID) != "" {
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
}
2025-12-22 22:58:31 +08:00
return creds
}
// Stop shuts down the service's session store by calling its Stop method,
// which ends the store's cleanup goroutine.
func (s *OpenAIOAuthService) Stop() {
	s.sessionStore.Stop()
}
// normalizeOpenAIOAuthPlatform maps a caller-supplied platform name to the
// platform identifier used for OAuth configuration. The argument is currently
// ignored: the result is always openai.OAuthPlatformOpenAI.
func normalizeOpenAIOAuthPlatform(_ string) string {
	return openai.OAuthPlatformOpenAI
}