// Mirror of https://gitee.com/wanwujie/sub2api (synced 2026-05-05 05:30:44 +08:00).
//
// Upstream change summary: Security: force token_uri to the Google default,
// preventing SSRF via a crafted service account JSON. Dedup: extract the shared
// getVertexServiceAccountAccessToken() to eliminate ~35 lines of duplication
// between ClaudeTokenProvider and GeminiTokenProvider. Fix: apply model mapping
// and Vertex model ID normalization in the forward_as_responses and
// forward_as_chat_completions paths. Fix: exclude service_account from AI Studio
// endpoint selection (Vertex cannot serve generativelanguage.googleapis.com).
// Feature: add model restriction/mapping UI for service_account in
// EditAccountModal. Dedup: extract VERTEX_LOCATION_OPTIONS to shared constants.
// i18n: replace all hardcoded Chinese strings in the Vertex UI with translation keys.
//
// File stats at sync time: 346 lines, 10 KiB, Go.

package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// Vertex AI endpoint defaults. vertexDefaultLocation is the region used when
// neither the model nor the account configures one; vertexDefaultTokenURL is the
// only token endpoint service-account assertions are posted to; and
// vertexCloudPlatformScope is the OAuth scope requested for exchanged tokens.
const (
	vertexDefaultLocation    = "us-central1"
	vertexDefaultTokenURL    = "https://oauth2.googleapis.com/token"
	vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
)

// Token-caching and request-shaping parameters. vertexServiceAccountCacheSkew is
// subtracted from a token's lifetime before it is cached; vertexLockWaitTime is
// how long a caller that lost the refresh lock waits before re-reading the cache;
// and vertexAnthropicVersion is injected as anthropic_version into Vertex Claude
// request bodies.
const (
	vertexServiceAccountCacheSkew = time.Minute * 5
	vertexLockWaitTime            = time.Millisecond * 200
	vertexAnthropicVersion        = "vertex-2023-10-16"
)
|
|
|
|
// vertexLocationPattern restricts a Vertex region to lowercase letters, digits
// and hyphens. The region is interpolated into the API hostname, so nothing else
// may pass.
var vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`)

// vertexAnthropicDatedModelIDPattern captures "<base>-<8 digits>" model IDs as
// (base, suffix) so they can be rewritten into the "<base>@<suffix>" form.
// The 8-digit suffix is presumably a YYYYMMDD release date.
var vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`)

// vertexAnthropicAlreadyDatedIDPattern matches IDs already in "<base>@<8 digits>" form.
var vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`)
|
|
|
|
// vertexServiceAccountKey holds the fields of a Google service account JSON key
// that this package decodes.
//
// ClientEmail becomes the JWT issuer and PrivateKey is the PEM-encoded RSA key
// used to sign the assertion. PrivateKeyID, when non-blank, is sent as the JWT
// "kid" header. ProjectID is the fallback project for Vertex requests. Type is
// decoded but not checked by the code here. TokenURI is decoded, but
// parseVertexServiceAccountJSON always overwrites it with vertexDefaultTokenURL,
// so a value supplied in the JSON is never used as a request target.
type vertexServiceAccountKey struct {
	Type         string `json:"type"`
	ProjectID    string `json:"project_id"`
	PrivateKeyID string `json:"private_key_id"`
	PrivateKey   string `json:"private_key"`
	ClientEmail  string `json:"client_email"`
	TokenURI     string `json:"token_uri"`
}
|
|
|
|
// vertexTokenResponse is the JSON body returned by the OAuth token endpoint.
//
// On success, AccessToken and ExpiresIn are read. ExpiresIn is in seconds; a
// value <= 0 is treated as one hour by the caller. On a non-2xx response,
// ErrorDesc ("error_description") and then Error ("error") are used to build
// the returned error message. TokenType is decoded but not read by the code here.
type vertexTokenResponse struct {
	AccessToken string `json:"access_token"`
	TokenType   string `json:"token_type"`
	ExpiresIn   int64  `json:"expires_in"`
	Error       string `json:"error"`
	ErrorDesc   string `json:"error_description"`
}
|
|
|
|
func (a *Account) IsVertexServiceAccount() bool {
|
|
return a != nil && a.Type == AccountTypeServiceAccount
|
|
}
|
|
|
|
func (a *Account) VertexProjectID() string {
|
|
if a == nil {
|
|
return ""
|
|
}
|
|
if v := strings.TrimSpace(a.GetCredential("project_id")); v != "" {
|
|
return v
|
|
}
|
|
key, err := parseVertexServiceAccountKey(a)
|
|
if err == nil {
|
|
return strings.TrimSpace(key.ProjectID)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (a *Account) VertexLocation(model string) string {
|
|
if a == nil {
|
|
return vertexDefaultLocation
|
|
}
|
|
if model != "" && a.Credentials != nil {
|
|
if raw, ok := a.Credentials["vertex_model_locations"].(map[string]any); ok {
|
|
if loc, ok := raw[model].(string); ok && strings.TrimSpace(loc) != "" {
|
|
return strings.TrimSpace(loc)
|
|
}
|
|
}
|
|
}
|
|
if v := strings.TrimSpace(a.GetCredential("location")); v != "" {
|
|
return v
|
|
}
|
|
if v := strings.TrimSpace(a.GetCredential("vertex_location")); v != "" {
|
|
return v
|
|
}
|
|
return vertexDefaultLocation
|
|
}
|
|
|
|
func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) {
|
|
if account == nil || account.Credentials == nil {
|
|
return nil, errors.New("service account credentials not configured")
|
|
}
|
|
|
|
if raw := strings.TrimSpace(account.GetCredential("service_account_json")); raw != "" {
|
|
return parseVertexServiceAccountJSON([]byte(raw))
|
|
}
|
|
if raw := strings.TrimSpace(account.GetCredential("service_account")); raw != "" {
|
|
return parseVertexServiceAccountJSON([]byte(raw))
|
|
}
|
|
if nested, ok := account.Credentials["service_account_json"].(map[string]any); ok {
|
|
b, _ := json.Marshal(nested)
|
|
return parseVertexServiceAccountJSON(b)
|
|
}
|
|
if nested, ok := account.Credentials["service_account"].(map[string]any); ok {
|
|
b, _ := json.Marshal(nested)
|
|
return parseVertexServiceAccountJSON(b)
|
|
}
|
|
return nil, errors.New("service_account_json not found in credentials")
|
|
}
|
|
|
|
func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) {
|
|
var key vertexServiceAccountKey
|
|
if err := json.Unmarshal(raw, &key); err != nil {
|
|
return nil, fmt.Errorf("invalid service account json: %w", err)
|
|
}
|
|
if strings.TrimSpace(key.ClientEmail) == "" {
|
|
return nil, errors.New("service account json missing client_email")
|
|
}
|
|
if strings.TrimSpace(key.PrivateKey) == "" {
|
|
return nil, errors.New("service account json missing private_key")
|
|
}
|
|
if strings.TrimSpace(key.ProjectID) == "" {
|
|
return nil, errors.New("service account json missing project_id")
|
|
}
|
|
// Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri.
|
|
key.TokenURI = vertexDefaultTokenURL
|
|
return &key, nil
|
|
}
|
|
|
|
func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string {
|
|
fingerprint := ""
|
|
if key != nil {
|
|
sum := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID))
|
|
fingerprint = hex.EncodeToString(sum[:8])
|
|
}
|
|
if fingerprint == "" && account != nil {
|
|
fingerprint = fmt.Sprintf("account:%d", account.ID)
|
|
}
|
|
return "vertex:service_account:" + fingerprint
|
|
}
|
|
|
|
// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
|
|
// using the shared cache and distributed lock to avoid redundant exchanges.
|
|
func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
|
|
key, err := parseVertexServiceAccountKey(account)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
cacheKey := vertexServiceAccountCacheKey(account, key)
|
|
|
|
if cache != nil {
|
|
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
|
return token, nil
|
|
}
|
|
}
|
|
|
|
locked := false
|
|
if cache != nil {
|
|
var lockErr error
|
|
locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
|
if lockErr == nil && locked {
|
|
defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
|
|
} else if lockErr != nil {
|
|
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
|
} else {
|
|
time.Sleep(vertexLockWaitTime)
|
|
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
|
return token, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if cache != nil {
|
|
_ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
|
}
|
|
return accessToken, nil
|
|
}
|
|
|
|
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
|
|
now := time.Now()
|
|
claims := jwt.MapClaims{
|
|
"iss": key.ClientEmail,
|
|
"scope": vertexCloudPlatformScope,
|
|
"aud": key.TokenURI,
|
|
"iat": now.Unix(),
|
|
"exp": now.Add(time.Hour).Unix(),
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
if strings.TrimSpace(key.PrivateKeyID) != "" {
|
|
token.Header["kid"] = key.PrivateKeyID
|
|
}
|
|
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey))
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("parse service account private key: %w", err)
|
|
}
|
|
assertion, err := token.SignedString(privateKey)
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("sign service account assertion: %w", err)
|
|
}
|
|
|
|
values := url.Values{}
|
|
values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
|
values.Set("assertion", assertion)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(values.Encode()))
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
client := &http.Client{Timeout: 15 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("service account token request failed: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
var parsed vertexTokenResponse
|
|
_ = json.Unmarshal(body, &parsed)
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
msg := strings.TrimSpace(parsed.ErrorDesc)
|
|
if msg == "" {
|
|
msg = strings.TrimSpace(parsed.Error)
|
|
}
|
|
if msg == "" {
|
|
msg = string(bytes.TrimSpace(body))
|
|
}
|
|
return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg)
|
|
}
|
|
if strings.TrimSpace(parsed.AccessToken) == "" {
|
|
return "", 0, errors.New("service account token response missing access_token")
|
|
}
|
|
ttl := time.Duration(parsed.ExpiresIn) * time.Second
|
|
if ttl <= 0 {
|
|
ttl = time.Hour
|
|
}
|
|
if ttl > vertexServiceAccountCacheSkew {
|
|
ttl -= vertexServiceAccountCacheSkew
|
|
}
|
|
return parsed.AccessToken, ttl, nil
|
|
}
|
|
|
|
func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
|
|
projectID = strings.TrimSpace(projectID)
|
|
location = strings.TrimSpace(location)
|
|
model = strings.TrimSpace(model)
|
|
action = strings.TrimSpace(action)
|
|
if projectID == "" {
|
|
return "", errors.New("vertex project_id is required")
|
|
}
|
|
if location == "" {
|
|
location = vertexDefaultLocation
|
|
}
|
|
if !vertexLocationPattern.MatchString(location) {
|
|
return "", fmt.Errorf("invalid vertex location: %s", location)
|
|
}
|
|
if model == "" {
|
|
return "", errors.New("vertex model is required")
|
|
}
|
|
switch action {
|
|
case "generateContent", "streamGenerateContent", "countTokens":
|
|
default:
|
|
return "", fmt.Errorf("unsupported vertex gemini action: %s", action)
|
|
}
|
|
host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
|
|
if location == "global" {
|
|
host = "aiplatform.googleapis.com"
|
|
}
|
|
u := fmt.Sprintf(
|
|
"https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
|
host,
|
|
url.PathEscape(projectID),
|
|
url.PathEscape(location),
|
|
url.PathEscape(model),
|
|
action,
|
|
)
|
|
if stream {
|
|
u += "?alt=sse"
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) {
|
|
projectID = strings.TrimSpace(projectID)
|
|
location = strings.TrimSpace(location)
|
|
model = strings.TrimSpace(model)
|
|
if projectID == "" {
|
|
return "", errors.New("vertex project_id is required")
|
|
}
|
|
if location == "" {
|
|
location = vertexDefaultLocation
|
|
}
|
|
if !vertexLocationPattern.MatchString(location) {
|
|
return "", fmt.Errorf("invalid vertex location: %s", location)
|
|
}
|
|
if model == "" {
|
|
return "", errors.New("vertex model is required")
|
|
}
|
|
action := "rawPredict"
|
|
if stream {
|
|
action = "streamRawPredict"
|
|
}
|
|
host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
|
|
if location == "global" {
|
|
host = "aiplatform.googleapis.com"
|
|
}
|
|
escapedModel := strings.ReplaceAll(url.PathEscape(model), "%40", "@")
|
|
return fmt.Sprintf(
|
|
"https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
|
host,
|
|
url.PathEscape(projectID),
|
|
url.PathEscape(location),
|
|
escapedModel,
|
|
action,
|
|
), nil
|
|
}
|
|
|
|
func normalizeVertexAnthropicModelID(model string) string {
|
|
model = strings.TrimSpace(model)
|
|
if model == "" || vertexAnthropicAlreadyDatedIDPattern.MatchString(model) {
|
|
return model
|
|
}
|
|
if m := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(model); len(m) == 3 {
|
|
return m[1] + "@" + m[2]
|
|
}
|
|
return model
|
|
}
|
|
|
|
func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) {
|
|
var payload map[string]any
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
return nil, fmt.Errorf("parse anthropic vertex request body: %w", err)
|
|
}
|
|
delete(payload, "model")
|
|
payload["anthropic_version"] = vertexAnthropicVersion
|
|
return json.Marshal(payload)
|
|
}
|