Files
sub2api/backend/internal/service/identity_service.go

461 lines
16 KiB
Go
Raw Permalink Normal View History

2025-12-18 13:50:39 +08:00
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log/slog"
2025-12-18 13:50:39 +08:00
"net/http"
"regexp"
"strconv"
"strings"
2025-12-18 13:50:39 +08:00
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
2025-12-18 13:50:39 +08:00
)
// userIDRegex recognises metadata.user_id values in the two accepted shapes:
//
//	legacy:  user_{64 hex}_account__session_{uuid}        (empty account segment)
//	current: user_{64 hex}_account_{uuid}_session_{uuid}  (account UUID present)
//
// Capture group 1 is the (possibly empty) account segment and group 2 the
// 36-character session segment. It is compiled once at package initialisation
// so hot request paths never recompile it.
var userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)

// userAgentVersionRegex captures the first "/major.minor.patch" triple in a
// User-Agent string, e.g. "claude-cli/2.1.22 (external, cli)" -> 2, 1, 22.
var userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
// 默认指纹值(当客户端未提供时使用)
2025-12-25 17:15:01 +08:00
var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.1.22 (external, cli)",
2025-12-18 13:50:39 +08:00
StainlessLang: "js",
StainlessPackageVersion: "0.70.0",
2025-12-18 13:50:39 +08:00
StainlessOS: "Linux",
StainlessArch: "arm64",
2025-12-18 13:50:39 +08:00
StainlessRuntime: "node",
StainlessRuntimeVersion: "v24.13.0",
2025-12-18 13:50:39 +08:00
}
2025-12-25 17:15:01 +08:00
// Fingerprint is the client identity pinned to one account.
//
// ClientID is a random 64-hex-character ID assigned when the fingerprint is
// first created. UserAgent and the Stainless* fields are the User-Agent and
// X-Stainless-* header values that ApplyFingerprint writes onto outgoing
// requests.
type Fingerprint struct {
	ClientID, UserAgent string

	// X-Stainless-* header values.
	StainlessLang, StainlessPackageVersion    string
	StainlessOS, StainlessArch                string
	StainlessRuntime, StainlessRuntimeVersion string

	// UpdatedAt is the Unix time (seconds) of the last cache write; it is used
	// to decide when the cached entry's TTL should be renewed.
	UpdatedAt int64 `json:",omitempty"`
}
// IdentityCache defines cache operations for identity service
type IdentityCache interface {
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
// GetMaskedSessionID 获取固定的会话ID用于会话ID伪装功能
// 返回的 sessionID 是一个 UUID 格式的字符串
// 如果不存在或已过期15分钟无请求返回空字符串
GetMaskedSessionID(ctx context.Context, accountID int64) (string, error)
// SetMaskedSessionID 设置固定的会话IDTTL 为 15 分钟
// 每次调用都会刷新 TTL
SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error
2025-12-25 17:15:01 +08:00
}
2025-12-18 13:50:39 +08:00
// IdentityService 管理OAuth账号的请求身份指纹
type IdentityService struct {
2025-12-25 17:15:01 +08:00
cache IdentityCache
2025-12-18 13:50:39 +08:00
}
// NewIdentityService 创建新的IdentityService
2025-12-25 17:15:01 +08:00
func NewIdentityService(cache IdentityCache) *IdentityService {
return &IdentityService{cache: cache}
2025-12-18 13:50:39 +08:00
}
// GetOrCreateFingerprint 获取或创建账号的指纹
// 如果缓存存在检测user-agent版本新版本则更新
// 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存
2025-12-25 17:15:01 +08:00
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
// 尝试从缓存获取指纹
cached, err := s.cache.GetFingerprint(ctx, accountID)
if err == nil && cached != nil {
needWrite := false
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
// 版本升级merge 语义 — 仅更新请求中实际携带的字段,保留缓存值
// 避免缺失的头被硬编码默认值覆盖(如新 CLI 版本 + 旧 SDK 默认值的不一致)
mergeHeadersIntoFingerprint(cached, headers)
needWrite = true
logger.LegacyPrintf("service.identity", "Updated fingerprint for account %d: %s (merge update)", accountID, clientUA)
} else if time.Since(time.Unix(cached.UpdatedAt, 0)) > 24*time.Hour {
// 距上次写入超过24小时续期TTL
needWrite = true
}
if needWrite {
cached.UpdatedAt = time.Now().Unix()
if err := s.cache.SetFingerprint(ctx, accountID, cached); err != nil {
logger.LegacyPrintf("service.identity", "Warning: failed to refresh fingerprint for account %d: %v", accountID, err)
}
2025-12-18 13:50:39 +08:00
}
return cached, nil
2025-12-18 13:50:39 +08:00
}
// 缓存不存在或解析失败,创建新指纹
fp := s.createFingerprintFromHeaders(headers)
// 生成随机ClientID
fp.ClientID = generateClientID()
fp.UpdatedAt = time.Now().Unix()
2025-12-18 13:50:39 +08:00
// 保存到缓存7天TTL每24小时自动续期
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err)
2025-12-18 13:50:39 +08:00
}
logger.LegacyPrintf("service.identity", "Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
2025-12-18 13:50:39 +08:00
return fp, nil
}
// createFingerprintFromHeaders 从请求头创建指纹
2025-12-25 17:15:01 +08:00
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
fp := &Fingerprint{}
2025-12-18 13:50:39 +08:00
// 获取User-Agent
if ua := headers.Get("User-Agent"); ua != "" {
fp.UserAgent = ua
} else {
fp.UserAgent = defaultFingerprint.UserAgent
}
// 获取x-stainless-*头,如果没有则使用默认值
fp.StainlessLang = getHeaderOrDefault(headers, "X-Stainless-Lang", defaultFingerprint.StainlessLang)
fp.StainlessPackageVersion = getHeaderOrDefault(headers, "X-Stainless-Package-Version", defaultFingerprint.StainlessPackageVersion)
fp.StainlessOS = getHeaderOrDefault(headers, "X-Stainless-OS", defaultFingerprint.StainlessOS)
fp.StainlessArch = getHeaderOrDefault(headers, "X-Stainless-Arch", defaultFingerprint.StainlessArch)
fp.StainlessRuntime = getHeaderOrDefault(headers, "X-Stainless-Runtime", defaultFingerprint.StainlessRuntime)
fp.StainlessRuntimeVersion = getHeaderOrDefault(headers, "X-Stainless-Runtime-Version", defaultFingerprint.StainlessRuntimeVersion)
return fp
}
// mergeHeadersIntoFingerprint overwrites fields of fp with request header
// values, but only for headers that are present and non-empty. All other
// fields keep their current (cached) values.
//
// It is used for version upgrades. createFingerprintFromHeaders, by contrast,
// substitutes defaultFingerprint values for missing headers. Merging avoids
// downgrading known real values to hard-coded defaults.
func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) {
	mappings := []struct {
		header string
		field  *string
	}{
		{"User-Agent", &fp.UserAgent},
		{"X-Stainless-Lang", &fp.StainlessLang},
		{"X-Stainless-Package-Version", &fp.StainlessPackageVersion},
		{"X-Stainless-OS", &fp.StainlessOS},
		{"X-Stainless-Arch", &fp.StainlessArch},
		{"X-Stainless-Runtime", &fp.StainlessRuntime},
		{"X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion},
	}
	for _, m := range mappings {
		mergeHeader(headers, m.header, m.field)
	}
}
// mergeHeader stores the value of header key into *target when the request
// carries a non-empty value for it; otherwise *target is left untouched.
func mergeHeader(headers http.Header, key string, target *string) {
	value := headers.Get(key)
	if value == "" {
		return
	}
	*target = value
}
2025-12-18 13:50:39 +08:00
// getHeaderOrDefault returns the value of header key, or defaultValue when the
// header is absent or empty.
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
	value := headers.Get(key)
	if value == "" {
		value = defaultValue
	}
	return value
}
// ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头)
2025-12-25 17:15:01 +08:00
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
2025-12-18 13:50:39 +08:00
if fp == nil {
return
}
2025-12-22 22:58:31 +08:00
// 设置user-agent
2025-12-18 13:50:39 +08:00
if fp.UserAgent != "" {
2025-12-22 22:58:31 +08:00
req.Header.Set("user-agent", fp.UserAgent)
2025-12-18 13:50:39 +08:00
}
2025-12-22 22:58:31 +08:00
// 设置x-stainless-*头
2025-12-18 13:50:39 +08:00
if fp.StainlessLang != "" {
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
}
if fp.StainlessPackageVersion != "" {
req.Header.Set("X-Stainless-Package-Version", fp.StainlessPackageVersion)
}
if fp.StainlessOS != "" {
req.Header.Set("X-Stainless-OS", fp.StainlessOS)
}
if fp.StainlessArch != "" {
req.Header.Set("X-Stainless-Arch", fp.StainlessArch)
}
if fp.StainlessRuntime != "" {
req.Header.Set("X-Stainless-Runtime", fp.StainlessRuntime)
}
if fp.StainlessRuntimeVersion != "" {
req.Header.Set("X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion)
}
}
// RewriteUserID 重写body中的metadata.user_id
// 输入格式user_{clientId}_account__session_{sessionUUID}
// 输出格式user_{cachedClientID}_account_{accountUUID}_session_{newHash}
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
2025-12-18 13:50:39 +08:00
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil
}
// 使用 RawMessage 保留其他字段的原始字节
var reqMap map[string]json.RawMessage
2025-12-18 13:50:39 +08:00
if err := json.Unmarshal(body, &reqMap); err != nil {
return body, nil
}
// 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
2025-12-18 13:50:39 +08:00
if !ok {
return body, nil
}
var metadata map[string]any
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
return body, nil
}
2025-12-18 13:50:39 +08:00
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return body, nil
}
// 匹配格式:
// 旧格式: user_{64位hex}_account__session_{uuid}
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
2025-12-18 13:50:39 +08:00
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil
}
// matches[1] = account UUID (可能为空), matches[2] = session UUID
sessionTail := matches[2] // 原始session UUID
2025-12-18 13:50:39 +08:00
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
metadata["user_id"] = newUserID
// 只重新序列化 metadata 字段
newMetadataRaw, err := json.Marshal(metadata)
if err != nil {
return body, nil
}
reqMap["metadata"] = newMetadataRaw
2025-12-18 13:50:39 +08:00
return json.Marshal(reqMap)
}
// RewriteUserIDWithMasking runs RewriteUserID and then, when session ID masking
// is enabled for the account (account.IsSessionIDMaskingEnabled), replaces
// everything after the last "_session_" in metadata.user_id with a per-account
// masked session ID.
//
// The masked ID is read from the cache. A new random UUID is generated when
// the cache returns "". Every call writes the ID back to refresh its TTL (the
// IdentityCache contract describes a 15-minute TTL), so requests inside that
// window share one session ID.
//
// Cache failures are non-fatal. A read error logs a warning and returns the
// RewriteUserID result; a write error is only logged. Any other parse failure
// also returns the RewriteUserID result unchanged.
//
// In this function only metadata.user_id is decoded and re-encoded. All other
// fields are carried as json.RawMessage, so numeric values keep full precision.
// account must be non-nil: account.ID is read before any check.
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
	// Regular rewrite first.
	newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
	if err != nil {
		return newBody, err
	}
	if !account.IsSessionIDMaskingEnabled() {
		return newBody, nil
	}

	var reqMap map[string]json.RawMessage
	if err := json.Unmarshal(newBody, &reqMap); err != nil {
		return newBody, nil
	}
	metadataRaw, ok := reqMap["metadata"]
	if !ok {
		return newBody, nil
	}
	// Raw values avoid round-tripping other metadata fields through Go types
	// (float64 precision loss, nested key re-sorting).
	var metadata map[string]json.RawMessage
	if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
		return newBody, nil
	}
	userIDRaw, ok := metadata["user_id"]
	if !ok {
		return newBody, nil
	}
	var userID string
	if err := json.Unmarshal(userIDRaw, &userID); err != nil || userID == "" {
		return newBody, nil
	}

	// Locate the session segment; everything after the marker is replaced.
	const sessionMarker = "_session_"
	idx := strings.LastIndex(userID, sessionMarker)
	if idx == -1 {
		return newBody, nil
	}

	maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
	if err != nil {
		logger.LegacyPrintf("service.identity", "Warning: failed to get masked session ID for account %d: %v", account.ID, err)
		return newBody, nil
	}
	if maskedSessionID == "" {
		// First use or expired: mint a new masked session ID.
		maskedSessionID = generateRandomUUID()
		logger.LegacyPrintf("service.identity", "Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
	}
	// Write back on every request so the TTL keeps sliding.
	if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
		logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
	}

	newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
	slog.Debug("session_id_masking_applied",
		"account_id", account.ID,
		"before", userID,
		"after", newUserID,
	)

	newUserIDRaw, marshalErr := json.Marshal(newUserID)
	if marshalErr != nil {
		return newBody, nil
	}
	metadata["user_id"] = newUserIDRaw

	// Only the metadata object is re-serialised; other top-level fields stay raw.
	newMetadataRaw, marshalErr := json.Marshal(metadata)
	if marshalErr != nil {
		return newBody, nil
	}
	reqMap["metadata"] = newMetadataRaw
	return json.Marshal(reqMap)
}
// generateRandomUUID returns a random version-4 UUID string in 8-4-4-4-12
// lowercase hex form. If crypto/rand fails, the 16 bytes are taken from the
// SHA-256 of the current Unix time in nanoseconds instead.
func generateRandomUUID() string {
	var raw [16]byte
	buf := raw[:]
	if _, err := rand.Read(buf); err != nil {
		sum := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
		buf = sum[:16]
	}
	buf[6] = buf[6]&0x0f | 0x40 // version 4
	buf[8] = buf[8]&0x3f | 0x80 // RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", buf[:4], buf[4:6], buf[6:8], buf[8:10], buf[10:])
}
2025-12-18 13:50:39 +08:00
// generateClientID 生成64位十六进制客户端ID32字节随机数
func generateClientID() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
// 极罕见的情况,使用时间戳+固定值作为fallback
logger.LegacyPrintf("service.identity", "Warning: crypto/rand.Read failed: %v, using fallback", err)
2025-12-18 13:50:39 +08:00
// 使用SHA256(当前纳秒时间)作为fallback
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
return hex.EncodeToString(h[:])
}
return hex.EncodeToString(b)
}
// generateUUIDFromSeed deterministically maps seed to a version-4-shaped UUID
// string. It takes the first 16 bytes of SHA-256(seed), forces the version and
// variant bits, and formats them as 8-4-4-4-12 lowercase hex. Equal seeds
// always yield equal UUIDs.
func generateUUIDFromSeed(seed string) string {
	digest := sha256.Sum256([]byte(seed))
	u := digest[:16]
	u[6] = 0x40 | u[6]&0x0f // version 4
	u[8] = 0x80 | u[8]&0x3f // RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", u[0:4], u[4:6], u[6:8], u[8:10], u[10:16])
}
// parseUserAgentVersion 解析user-agent版本号
// 例如claude-cli/2.1.2 -> (2, 1, 2)
2025-12-18 13:50:39 +08:00
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// 匹配 xxx/x.y.z 格式
matches := userAgentVersionRegex.FindStringSubmatch(ua)
if len(matches) != 4 {
return 0, 0, 0, false
}
major, _ = strconv.Atoi(matches[1])
minor, _ = strconv.Atoi(matches[2])
patch, _ = strconv.Atoi(matches[3])
return major, minor, patch, true
}
// extractProduct returns the lower-cased product token before the first "/"
// in a User-Agent, e.g. "claude-cli/2.1.22 (external, cli)" -> "claude-cli".
// It returns "" when there is no "/" or nothing precedes it.
func extractProduct(ua string) string {
	product, _, found := strings.Cut(ua, "/")
	if !found || product == "" {
		return ""
	}
	return strings.ToLower(product)
}
2025-12-18 13:50:39 +08:00
// isNewerVersion 比较版本号判断newUA是否比cachedUA更新
// 要求产品名一致(防止浏览器 UA 如 Mozilla/5.0 误判为更新版本)
2025-12-18 13:50:39 +08:00
func isNewerVersion(newUA, cachedUA string) bool {
// 校验产品名一致性
newProduct := extractProduct(newUA)
cachedProduct := extractProduct(cachedUA)
if newProduct == "" || cachedProduct == "" || newProduct != cachedProduct {
return false
}
2025-12-18 13:50:39 +08:00
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)
if !newOk || !cachedOk {
return false
}
// 比较版本号
if newMajor > cachedMajor {
return true
}
if newMajor < cachedMajor {
return false
}
if newMinor > cachedMinor {
return true
}
if newMinor < cachedMinor {
return false
}
return newPatch > cachedPatch
}