// Package service provides business logic and domain services for the application.
package service
import (
"encoding/json"
"hash/fnv"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
type Account struct {
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
// RateMultiplier 账号计费倍率(>=0允许 0 表示该账号计费为 0
// 使用指针用于兼容旧版本调度缓存Redis中缺字段的情况nil 表示按 1.0 处理。
RateMultiplier *float64
LoadFactor *int // 调度负载因子nil 表示使用 Concurrency
2026-01-07 16:59:35 +08:00
Status string
ErrorMessage string
LastUsedAt *time.Time
ExpiresAt *time.Time
AutoPauseOnExpired bool
CreatedAt time.Time
UpdatedAt time.Time
Schedulable bool
RateLimitedAt *time.Time
RateLimitResetAt *time.Time
OverloadUntil *time.Time
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
Proxy *Proxy
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
// model_mapping 热路径缓存(非持久化字段)
modelMappingCache map[string]string
modelMappingCacheReady bool
modelMappingCacheCredentialsPtr uintptr
modelMappingCacheRawPtr uintptr
modelMappingCacheRawLen int
modelMappingCacheRawSig uint64
}
// TempUnschedulableRule is one rule entry parsed from the account's
// "temp_unschedulable_rules" credential (see GetTempUnschedulableRules).
// Rules with a non-positive ErrorCode or DurationMinutes, or without any
// keyword, are discarded by that parser.
type TempUnschedulableRule struct {
	ErrorCode       int      `json:"error_code"`
	Keywords        []string `json:"keywords"`
	DurationMinutes int      `json:"duration_minutes"`
	Description     string   `json:"description"`
}
// IsActive reports whether the account status equals StatusActive.
func (a *Account) IsActive() bool {
	active := a.Status == StatusActive
	return active
}
// BillingRateMultiplier returns the billing multiplier of the account.
//   - nil (receiver or field) means "not configured / missing from an old
//     cache entry" and is treated as 1.0
//   - 0 is allowed and means the account is billed as 0
//   - negative values are invalid data and, to be safe, are treated as 1.0
func (a *Account) BillingRateMultiplier() float64 {
	const fallback = 1.0
	if a == nil {
		return fallback
	}
	multiplier := a.RateMultiplier
	switch {
	case multiplier == nil:
		return fallback
	case *multiplier < 0:
		return fallback
	default:
		return *multiplier
	}
}
// EffectiveLoadFactor returns the load factor used for scheduling: a positive
// LoadFactor when set, otherwise a positive Concurrency, otherwise 1. A nil
// receiver also yields 1.
func (a *Account) EffectiveLoadFactor() int {
	switch {
	case a == nil:
		return 1
	case a.LoadFactor != nil && *a.LoadFactor > 0:
		return *a.LoadFactor
	case a.Concurrency > 0:
		return a.Concurrency
	default:
		return 1
	}
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
2026-01-07 16:59:35 +08:00
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
return false
}
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
return true
}
// IsRateLimited reports whether RateLimitResetAt is set and still in the future.
func (a *Account) IsRateLimited() bool {
	resetAt := a.RateLimitResetAt
	return resetAt != nil && time.Now().Before(*resetAt)
}
// IsOverloaded reports whether OverloadUntil is set and still in the future.
func (a *Account) IsOverloaded() bool {
	until := a.OverloadUntil
	return until != nil && time.Now().Before(*until)
}
// IsOAuth reports whether the account type is OAuth or SetupToken.
func (a *Account) IsOAuth() bool {
	if a.Type == AccountTypeOAuth {
		return true
	}
	return a.Type == AccountTypeSetupToken
}
// IsGemini reports whether the account belongs to the Gemini platform.
func (a *Account) IsGemini() bool {
	onGemini := a.Platform == PlatformGemini
	return onGemini
}
// GeminiOAuthType returns the trimmed "oauth_type" credential of a Gemini
// OAuth account. When that credential is empty but "project_id" is set, it
// returns "code_assist". Non-Gemini or non-OAuth accounts yield "".
func (a *Account) GeminiOAuthType() string {
	isGeminiOAuth := a.Platform == PlatformGemini && a.Type == AccountTypeOAuth
	if !isGeminiOAuth {
		return ""
	}
	if declared := strings.TrimSpace(a.GetCredential("oauth_type")); declared != "" {
		return declared
	}
	if strings.TrimSpace(a.GetCredential("project_id")) != "" {
		return "code_assist"
	}
	return ""
}
// GeminiTierID returns the "tier_id" credential with surrounding whitespace removed.
func (a *Account) GeminiTierID() string {
	return strings.TrimSpace(a.GetCredential("tier_id"))
}
// IsGeminiCodeAssist reports whether a Gemini OAuth account is of the
// "code_assist" OAuth type. With an empty OAuth type, a non-empty "project_id"
// credential also counts as code assist.
func (a *Account) IsGeminiCodeAssist() bool {
	if a.Platform != PlatformGemini {
		return false
	}
	if a.Type != AccountTypeOAuth {
		return false
	}
	switch a.GeminiOAuthType() {
	case "":
		return strings.TrimSpace(a.GetCredential("project_id")) != ""
	case "code_assist":
		return true
	default:
		return false
	}
}
// CanGetUsage reports whether the account type is OAuth.
func (a *Account) CanGetUsage() bool {
	isOAuth := a.Type == AccountTypeOAuth
	return isOAuth
}
// GetCredential returns the credential stored under key as a string.
//
// Several value types are accepted because historical data may store fields
// such as expires_at either as a number or as a string. Missing keys, nil
// values and unsupported types yield "".
func (a *Account) GetCredential(key string) string {
	// Reading from a nil map is safe and reports "not found".
	raw, found := a.Credentials[key]
	if !found || raw == nil {
		return ""
	}
	if text, isString := raw.(string); isString {
		return text
	}
	// GORM datatypes.JSONMap decodes with UseNumber(), so numbers arrive as json.Number.
	if number, isNumber := raw.(json.Number); isNumber {
		return number.String()
	}
	// Plain JSON decoding produces float64; the fractional part is dropped.
	if f, isFloat := raw.(float64); isFloat {
		return strconv.FormatInt(int64(f), 10)
	}
	if i, isInt64 := raw.(int64); isInt64 {
		return strconv.FormatInt(i, 10)
	}
	if i, isInt := raw.(int); isInt {
		return strconv.Itoa(i)
	}
	return ""
}
// GetCredentialAsTime parses a timestamp credential. Accepted forms:
//   - RFC3339 string: "2025-01-01T00:00:00Z"
//   - Unix timestamp string (seconds): "1735689600"
//   - Unix timestamp number: 1735689600 (float64/int64/json.Number)
//
// It returns nil when the credential is empty or cannot be parsed.
func (a *Account) GetCredentialAsTime(key string) *time.Time {
	raw := a.GetCredential(key)
	if raw == "" {
		return nil
	}
	// RFC3339 first.
	parsed, err := time.Parse(time.RFC3339, raw)
	if err == nil {
		return &parsed
	}
	// Otherwise interpret a purely numeric string as Unix seconds.
	seconds, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return nil
	}
	parsed = time.Unix(seconds, 0)
	return &parsed
}
// GetCredentialAsInt64 reads an int64 credential; it is used for internal
// fields such as _token_version. Missing, nil or unparsable values yield 0.
func (a *Account) GetCredentialAsInt64(key string) int64 {
	if a == nil {
		return 0
	}
	raw, found := a.Credentials[key]
	if !found || raw == nil {
		return 0
	}
	switch typed := raw.(type) {
	case string:
		parsed, err := strconv.ParseInt(strings.TrimSpace(typed), 10, 64)
		if err != nil {
			return 0
		}
		return parsed
	case json.Number:
		parsed, err := typed.Int64()
		if err != nil {
			return 0
		}
		return parsed
	case int:
		return int64(typed)
	case int64:
		return typed
	case float64:
		return int64(typed)
	}
	return 0
}
// IsTempUnschedulableEnabled reports whether the "temp_unschedulable_enabled"
// credential is present as a bool with value true.
func (a *Account) IsTempUnschedulableEnabled() bool {
	// A nil map, a missing key, a nil value or a non-bool value all fail the assertion.
	flag, isBool := a.Credentials["temp_unschedulable_enabled"].(bool)
	return isBool && flag
}
// GetTempUnschedulableRules parses the "temp_unschedulable_rules" credential
// into a rule list.
//
// It returns nil when the credential is absent or is not a JSON array.
// Entries that are not objects, or whose error code / duration is not positive
// or that carry no keyword, are skipped.
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
	items, isList := a.Credentials["temp_unschedulable_rules"].([]any)
	if !isList {
		return nil
	}
	parsed := make([]TempUnschedulableRule, 0, len(items))
	for i := range items {
		fields, isObject := items[i].(map[string]any)
		if !isObject || fields == nil {
			continue
		}
		candidate := TempUnschedulableRule{
			ErrorCode:       parseTempUnschedInt(fields["error_code"]),
			Keywords:        parseTempUnschedStrings(fields["keywords"]),
			DurationMinutes: parseTempUnschedInt(fields["duration_minutes"]),
			Description:     parseTempUnschedString(fields["description"]),
		}
		if candidate.ErrorCode > 0 && candidate.DurationMinutes > 0 && len(candidate.Keywords) > 0 {
			parsed = append(parsed, candidate)
		}
	}
	return parsed
}
// parseTempUnschedString returns value as a whitespace-trimmed string, or ""
// when value is not a string.
func parseTempUnschedString(value any) string {
	if text, isString := value.(string); isString {
		return strings.TrimSpace(text)
	}
	return ""
}
// parseTempUnschedStrings converts value into a list of trimmed, non-empty
// strings. It accepts []string and []any (non-string elements are ignored);
// any other input, including nil, yields nil.
func parseTempUnschedStrings(value any) []string {
	var candidates []string
	switch list := value.(type) {
	case []string:
		candidates = list
	case []any:
		candidates = make([]string, 0, len(list))
		for i := range list {
			if text, isString := list[i].(string); isString {
				candidates = append(candidates, text)
			}
		}
	default:
		// Covers the nil interface as well as unsupported types.
		return nil
	}
	cleaned := make([]string, 0, len(candidates))
	for i := range candidates {
		if trimmed := strings.TrimSpace(candidates[i]); trimmed != "" {
			cleaned = append(cleaned, trimmed)
		}
	}
	return cleaned
}
// normalizeAccountNotes trims the notes text and collapses "no content" to
// nil: a nil pointer or a whitespace-only string yields nil, otherwise a
// pointer to the trimmed copy is returned (the input is not modified).
func normalizeAccountNotes(value *string) *string {
	if value == nil {
		return nil
	}
	trimmed := strings.TrimSpace(*value)
	if trimmed == "" {
		return nil
	}
	return &trimmed
}
// parseTempUnschedInt converts value to int. It accepts int, int64, float64
// (truncated), json.Number and numeric strings (trimmed); anything else, or a
// parse failure, yields 0.
func parseTempUnschedInt(value any) int {
	if n, isInt := value.(int); isInt {
		return n
	}
	if n, isInt64 := value.(int64); isInt64 {
		return int(n)
	}
	if f, isFloat := value.(float64); isFloat {
		return int(f)
	}
	if num, isNumber := value.(json.Number); isNumber {
		parsed, err := num.Int64()
		if err != nil {
			return 0
		}
		return int(parsed)
	}
	if text, isString := value.(string); isString {
		parsed, err := strconv.Atoi(strings.TrimSpace(text))
		if err != nil {
			return 0
		}
		return parsed
	}
	return 0
}
// GetModelMapping returns the account's resolved model mapping, reusing the
// result cached on the account when the underlying credentials look unchanged.
//
// The cached value is reused only when all of the following match what was
// recorded by the previous call: the Credentials map pointer, the raw
// "model_mapping" map pointer, its length, and its content signature
// (modelMappingSignature). Otherwise the mapping is resolved again via
// resolveModelMapping and the cache fields are overwritten.
//
// The cache fields are read and written here without any locking.
func (a *Account) GetModelMapping() map[string]string {
	credentialsPtr := mapPtr(a.Credentials)
	// Indexing a nil Credentials map is safe; rawMapping is nil when the entry
	// is missing or is not a map[string]any.
	rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
	rawPtr := mapPtr(rawMapping)
	rawLen := len(rawMapping)
	rawSig := uint64(0)
	rawSigReady := false
	// Cheap identity checks first; the signature is only computed when they pass.
	if a.modelMappingCacheReady &&
		a.modelMappingCacheCredentialsPtr == credentialsPtr &&
		a.modelMappingCacheRawPtr == rawPtr &&
		a.modelMappingCacheRawLen == rawLen {
		rawSig = modelMappingSignature(rawMapping)
		rawSigReady = true
		if a.modelMappingCacheRawSig == rawSig {
			return a.modelMappingCache
		}
	}
	mapping := a.resolveModelMapping(rawMapping)
	// Avoid hashing twice when the signature was already computed above.
	if !rawSigReady {
		rawSig = modelMappingSignature(rawMapping)
	}
	a.modelMappingCache = mapping
	a.modelMappingCacheReady = true
	a.modelMappingCacheCredentialsPtr = credentialsPtr
	a.modelMappingCacheRawPtr = rawPtr
	a.modelMappingCacheRawLen = rawLen
	a.modelMappingCacheRawSig = rawSig
	return mapping
}
// resolveModelMapping builds the effective model mapping from the raw
// "model_mapping" credential.
//
// Only string-valued entries are kept. When nothing usable is configured, the
// Antigravity platform falls back to domain.DefaultAntigravityModelMapping and
// every other platform gets nil. For Antigravity accounts with a custom
// mapping, a fixed set of default passthrough models is added unless already
// covered by the mapping.
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
	isAntigravity := a.Platform == domain.PlatformAntigravity
	// fallback is the result used whenever no custom entry is available.
	fallback := func() map[string]string {
		if isAntigravity {
			return domain.DefaultAntigravityModelMapping
		}
		return nil
	}
	if a.Credentials == nil || len(rawMapping) == 0 {
		return fallback()
	}
	resolved := make(map[string]string)
	for name, target := range rawMapping {
		if text, isString := target.(string); isString {
			resolved[name] = text
		}
	}
	if len(resolved) == 0 {
		return fallback()
	}
	if isAntigravity {
		ensureAntigravityDefaultPassthroughs(resolved, []string{
			"gemini-3-flash",
			"gemini-3.1-pro-high",
			"gemini-3.1-pro-low",
		})
	}
	return resolved
}
// mapPtr returns the address of the map's underlying data as reported by
// reflect, or 0 for a nil map. It is used to detect map replacement cheaply.
func mapPtr(m map[string]any) uintptr {
	if m != nil {
		return reflect.ValueOf(m).Pointer()
	}
	return 0
}
// modelMappingSignature returns an FNV-1a 64-bit hash of the raw mapping's
// content, independent of map iteration order. An empty or nil mapping hashes
// to 0. Per entry the hash covers: key, 0x00, the string value (or the single
// byte 0x01 for non-string values), 0xff.
func modelMappingSignature(rawMapping map[string]any) uint64 {
	count := len(rawMapping)
	if count == 0 {
		return 0
	}
	names := make([]string, 0, count)
	for name := range rawMapping {
		names = append(names, name)
	}
	// Sort so that the hash does not depend on map iteration order.
	sort.Strings(names)

	hasher := fnv.New64a()
	// hash.Hash documents that Write never returns an error.
	feed := func(chunk []byte) { _, _ = hasher.Write(chunk) }
	for i := range names {
		feed([]byte(names[i]))
		feed([]byte{0})
		if text, isString := rawMapping[names[i]].(string); isString {
			feed([]byte(text))
		} else {
			feed([]byte{1})
		}
		feed([]byte{0xff})
	}
	return hasher.Sum64()
}
// ensureAntigravityDefaultPassthrough adds an identity entry model -> model to
// mapping unless the model is already present as a key or is matched by one of
// the existing patterns. A nil mapping or empty model is a no-op.
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
	if mapping == nil || model == "" {
		return
	}
	if _, present := mapping[model]; present {
		return
	}
	covered := false
	for pattern := range mapping {
		if matchWildcard(pattern, model) {
			covered = true
			break
		}
	}
	if !covered {
		mapping[model] = model
	}
}
// ensureAntigravityDefaultPassthroughs applies
// ensureAntigravityDefaultPassthrough to each model in order.
func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []string) {
	for i := range models {
		ensureAntigravityDefaultPassthrough(mapping, models[i])
	}
}
// IsModelSupported reports whether requestedModel is covered by the account's
// model mapping (wildcards supported). An account without any mapping accepts
// every model.
func (a *Account) IsModelSupported(requestedModel string) bool {
	mapping := a.GetModelMapping()
	if len(mapping) == 0 {
		// No mapping configured: everything is allowed.
		return true
	}
	// Exact key first, then wildcard patterns.
	_, supported := mapping[requestedModel]
	if !supported {
		for pattern := range mapping {
			if matchWildcard(pattern, requestedModel) {
				supported = true
				break
			}
		}
	}
	return supported
}
// GetMappedModel returns the model name requestedModel maps to. An exact key
// wins; otherwise wildcard patterns are consulted (longest pattern first).
// Without any mapping the requested name is returned unchanged.
func (a *Account) GetMappedModel(requestedModel string) string {
	mapping := a.GetModelMapping()
	if len(mapping) == 0 {
		return requestedModel
	}
	target, exact := mapping[requestedModel]
	if exact {
		return target
	}
	return matchWildcardMapping(mapping, requestedModel)
}
// GetBaseURL returns the base URL of an API-key account: the "base_url"
// credential when set, otherwise "https://api.anthropic.com". For the
// Antigravity platform a configured base URL gets "/antigravity" appended.
// Non API-key accounts yield "".
func (a *Account) GetBaseURL() string {
	if a.Type != AccountTypeAPIKey {
		return ""
	}
	switch configured := a.GetCredential("base_url"); {
	case configured == "":
		return "https://api.anthropic.com"
	case a.Platform == PlatformAntigravity:
		return strings.TrimRight(configured, "/") + "/antigravity"
	default:
		return configured
	}
}
// GetGeminiBaseURL returns the base URL for the Gemini-compatible endpoint:
// the trimmed "base_url" credential, or defaultBaseURL when it is empty.
// API-key accounts on the Antigravity platform get "/antigravity" appended to
// a configured base URL.
func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string {
	configured := strings.TrimSpace(a.GetCredential("base_url"))
	switch {
	case configured == "":
		return defaultBaseURL
	case a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey:
		return strings.TrimRight(configured, "/") + "/antigravity"
	default:
		return configured
	}
}
// GetExtraString returns the string stored under key in Extra, or "" when the
// key is missing or the value is not a string.
func (a *Account) GetExtraString(key string) string {
	// A nil map, a missing key and a non-string value all leave text empty.
	text, _ := a.Extra[key].(string)
	return text
}
// GetClaudeUserID returns the first non-blank user id found, looking in this
// order: Extra "claude_user_id", Extra "anthropic_user_id", credential
// "claude_user_id", credential "anthropic_user_id". The value is trimmed.
func (a *Account) GetClaudeUserID() string {
	keys := [...]string{"claude_user_id", "anthropic_user_id"}
	for _, key := range keys {
		if id := strings.TrimSpace(a.GetExtraString(key)); id != "" {
			return id
		}
	}
	for _, key := range keys {
		if id := strings.TrimSpace(a.GetCredential(key)); id != "" {
			return id
		}
	}
	return ""
}
// matchAntigravityWildcard matches str against pattern, where only a single
// trailing "*" is treated as a wildcard (prefix match). Patterns without a
// trailing "*" must equal str exactly. Used for model_mapping lookups.
func matchAntigravityWildcard(pattern, str string) bool {
	if !strings.HasSuffix(pattern, "*") {
		return pattern == str
	}
	return strings.HasPrefix(str, strings.TrimSuffix(pattern, "*"))
}
// matchWildcard is the platform-neutral wildcard matcher (trailing "*" only).
// It delegates to matchAntigravityWildcard so that all platforms share the
// same matching rules.
func matchWildcard(pattern, str string) bool {
	matched := matchAntigravityWildcard(pattern, str)
	return matched
}
// matchWildcardMapping returns the target of the best pattern in mapping that
// matches requestedModel, or requestedModel itself when nothing matches.
//
// "Best" means the longest pattern; among patterns of equal length the
// lexicographically smallest one wins, which keeps the result deterministic
// despite random map iteration order.
//
// The winner is selected in a single pass without allocating, instead of
// collecting and sorting every match, because this runs on the request path.
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
	var (
		bestPattern string
		bestTarget  string
		found       bool
	)
	for pattern, target := range mapping {
		if !matchWildcard(pattern, requestedModel) {
			continue
		}
		better := !found ||
			len(pattern) > len(bestPattern) ||
			(len(pattern) == len(bestPattern) && pattern < bestPattern)
		if better {
			bestPattern, bestTarget, found = pattern, target, true
		}
	}
	if !found {
		// No pattern matched: keep the original model name.
		return requestedModel
	}
	return bestTarget
}
// IsCustomErrorCodesEnabled reports whether an API-key account has the
// "custom_error_codes_enabled" credential set to the bool true.
func (a *Account) IsCustomErrorCodesEnabled() bool {
	if a.Type != AccountTypeAPIKey {
		return false
	}
	// Nil map, missing key and non-bool values all fail the assertion.
	enabled, isBool := a.Credentials["custom_error_codes_enabled"].(bool)
	return isBool && enabled
}
// IsPoolMode reports whether an API-key account has pool mode enabled
// (credential "pool_mode" set to the bool true).
// In pool mode, upstream errors do not mark the local account state; the
// request is retried on the same account instead.
func (a *Account) IsPoolMode() bool {
	if a.Type != AccountTypeAPIKey {
		return false
	}
	// Nil map, missing key and non-bool values all fail the assertion.
	enabled, isBool := a.Credentials["pool_mode"].(bool)
	return isBool && enabled
}
// Pool-mode retry bounds used by GetPoolModeRetryCount.
const (
	// defaultPoolModeRetryCount is used when the retry count is not configured
	// or cannot be parsed.
	defaultPoolModeRetryCount = 3
	// maxPoolModeRetryCount is the upper clamp applied to configured values.
	maxPoolModeRetryCount = 10
)
// GetPoolModeRetryCount returns how many same-account retries pool mode may
// perform. It falls back to the default (3) when the account is nil, is not in
// pool mode, or has no usable "pool_mode_retry_count" credential. Negative
// values are raised to 0 and values above the maximum are clamped to 10.
func (a *Account) GetPoolModeRetryCount() int {
	if a == nil || !a.IsPoolMode() {
		return defaultPoolModeRetryCount
	}
	// Looking up a nil Credentials map simply reports "not found".
	raw, found := a.Credentials["pool_mode_retry_count"]
	if !found || raw == nil {
		return defaultPoolModeRetryCount
	}
	switch count := parsePoolModeRetryCount(raw); {
	case count < 0:
		return 0
	case count > maxPoolModeRetryCount:
		return maxPoolModeRetryCount
	default:
		return count
	}
}
// parsePoolModeRetryCount converts value to int. It accepts int, int64,
// float64 (truncated), json.Number and numeric strings (trimmed). Unsupported
// types and parse failures yield defaultPoolModeRetryCount.
func parsePoolModeRetryCount(value any) int {
	switch typed := value.(type) {
	case string:
		parsed, err := strconv.Atoi(strings.TrimSpace(typed))
		if err != nil {
			return defaultPoolModeRetryCount
		}
		return parsed
	case json.Number:
		parsed, err := typed.Int64()
		if err != nil {
			return defaultPoolModeRetryCount
		}
		return int(parsed)
	case float64:
		return int(typed)
	case int64:
		return int(typed)
	case int:
		return typed
	}
	return defaultPoolModeRetryCount
}
// isPoolModeRetryableStatus reports whether statusCode should trigger a
// same-account retry in pool mode (401, 403 or 429).
func isPoolModeRetryableStatus(statusCode int) bool {
	return statusCode == 401 || statusCode == 403 || statusCode == 429
}
// GetCustomErrorCodes returns the status codes listed in the
// "custom_error_codes" credential.
//
// It returns nil when the receiver is nil, the credential is missing, or it is
// not a JSON array. Array elements may be float64 (plain JSON decoding),
// json.Number (decoding with UseNumber), int or int64; elements of any other
// type, or numbers that are not valid integers, are skipped.
//
// Accepting json.Number matters because an empty result makes
// ShouldHandleErrorCode treat every status code as handled.
func (a *Account) GetCustomErrorCodes() []int {
	if a == nil || a.Credentials == nil {
		return nil
	}
	raw, ok := a.Credentials["custom_error_codes"]
	if !ok || raw == nil {
		return nil
	}
	items, ok := raw.([]any)
	if !ok {
		return nil
	}
	codes := make([]int, 0, len(items))
	for _, item := range items {
		switch v := item.(type) {
		case float64:
			codes = append(codes, int(v))
		case json.Number:
			if n, err := v.Int64(); err == nil {
				codes = append(codes, int(n))
			}
		case int:
			codes = append(codes, v)
		case int64:
			codes = append(codes, int(v))
		}
	}
	return codes
}
// ShouldHandleErrorCode reports whether statusCode should be handled for this
// account. Every code is handled unless custom error codes are enabled and a
// non-empty code list is configured; in that case only listed codes are.
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
	if !a.IsCustomErrorCodesEnabled() {
		return true
	}
	codes := a.GetCustomErrorCodes()
	// An empty list means "no restriction".
	handled := len(codes) == 0
	for i := 0; i < len(codes) && !handled; i++ {
		handled = codes[i] == statusCode
	}
	return handled
}
// IsInterceptWarmupEnabled reports whether the "intercept_warmup_requests"
// credential is set to the bool true.
func (a *Account) IsInterceptWarmupEnabled() bool {
	// Nil map, missing key and non-bool values all fail the assertion.
	enabled, isBool := a.Credentials["intercept_warmup_requests"].(bool)
	return isBool && enabled
}
// IsOpenAI reports whether the account belongs to the OpenAI platform.
func (a *Account) IsOpenAI() bool {
	onOpenAI := a.Platform == PlatformOpenAI
	return onOpenAI
}
// IsAnthropic reports whether the account belongs to the Anthropic platform.
func (a *Account) IsAnthropic() bool {
	onAnthropic := a.Platform == PlatformAnthropic
	return onAnthropic
}
// IsOpenAIOAuth reports whether this is an OpenAI account of type OAuth.
func (a *Account) IsOpenAIOAuth() bool {
	if !a.IsOpenAI() {
		return false
	}
	return a.Type == AccountTypeOAuth
}
// IsOpenAIApiKey reports whether this is an OpenAI account of type API key.
func (a *Account) IsOpenAIApiKey() bool {
	if !a.IsOpenAI() {
		return false
	}
	return a.Type == AccountTypeAPIKey
}
// GetOpenAIBaseURL returns the base URL for an OpenAI account. API-key
// accounts may override it through the "base_url" credential; everything else
// uses "https://api.openai.com". Non-OpenAI accounts yield "".
func (a *Account) GetOpenAIBaseURL() string {
	if !a.IsOpenAI() {
		return ""
	}
	const officialBaseURL = "https://api.openai.com"
	if a.Type != AccountTypeAPIKey {
		return officialBaseURL
	}
	if custom := a.GetCredential("base_url"); custom != "" {
		return custom
	}
	return officialBaseURL
}
// GetOpenAIAccessToken returns the "access_token" credential of an OpenAI
// account, or "" for other platforms.
func (a *Account) GetOpenAIAccessToken() string {
	if a.IsOpenAI() {
		return a.GetCredential("access_token")
	}
	return ""
}
// GetOpenAIRefreshToken returns the "refresh_token" credential of an OpenAI
// OAuth account, or "" otherwise.
func (a *Account) GetOpenAIRefreshToken() string {
	if a.IsOpenAIOAuth() {
		return a.GetCredential("refresh_token")
	}
	return ""
}
// GetOpenAIIDToken returns the "id_token" credential of an OpenAI OAuth
// account, or "" otherwise.
func (a *Account) GetOpenAIIDToken() string {
	if a.IsOpenAIOAuth() {
		return a.GetCredential("id_token")
	}
	return ""
}
// GetOpenAIApiKey returns the "api_key" credential of an OpenAI API-key
// account, or "" otherwise.
func (a *Account) GetOpenAIApiKey() string {
	if a.IsOpenAIApiKey() {
		return a.GetCredential("api_key")
	}
	return ""
}
// GetOpenAIUserAgent returns the "user_agent" credential of an OpenAI account,
// or "" for other platforms.
func (a *Account) GetOpenAIUserAgent() string {
	if a.IsOpenAI() {
		return a.GetCredential("user_agent")
	}
	return ""
}
// GetChatGPTAccountID returns the "chatgpt_account_id" credential of an OpenAI
// OAuth account, or "" otherwise.
func (a *Account) GetChatGPTAccountID() string {
	if a.IsOpenAIOAuth() {
		return a.GetCredential("chatgpt_account_id")
	}
	return ""
}
// GetChatGPTUserID returns the "chatgpt_user_id" credential of an OpenAI OAuth
// account, or "" otherwise.
func (a *Account) GetChatGPTUserID() string {
	if a.IsOpenAIOAuth() {
		return a.GetCredential("chatgpt_user_id")
	}
	return ""
}
// GetOpenAIOrganizationID returns the "organization_id" credential of an
// OpenAI OAuth account, or "" otherwise.
func (a *Account) GetOpenAIOrganizationID() string {
	if a.IsOpenAIOAuth() {
		return a.GetCredential("organization_id")
	}
	return ""
}
// GetOpenAITokenExpiresAt returns the parsed "expires_at" credential of an
// OpenAI OAuth account, or nil for other accounts or unparsable values.
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
	if a.IsOpenAIOAuth() {
		return a.GetCredentialAsTime("expires_at")
	}
	return nil
}
// IsOpenAITokenExpired reports whether the OpenAI token expires within the
// next 60 seconds (or already has). An unknown expiry counts as not expired.
func (a *Account) IsOpenAITokenExpired() bool {
	expiresAt := a.GetOpenAITokenExpiresAt()
	if expiresAt == nil {
		return false
	}
	// Treat tokens as expired slightly early.
	threshold := time.Now().Add(60 * time.Second)
	return threshold.After(*expiresAt)
}
// IsMixedSchedulingEnabled reports whether an Antigravity account has mixed
// scheduling enabled (Extra "mixed_scheduling" set to the bool true).
// When enabled, the account can take part in scheduling for anthropic/gemini
// groups.
func (a *Account) IsMixedSchedulingEnabled() bool {
	if a.Platform != PlatformAntigravity {
		return false
	}
	// Nil map, missing key and non-bool values all fail the assertion.
	enabled, isBool := a.Extra["mixed_scheduling"].(bool)
	return isBool && enabled
}
// IsOpenAIPassthroughEnabled reports whether an OpenAI account has "automatic
// passthrough (only the authentication is replaced)" enabled.
//
// Current field: accounts.extra.openai_passthrough.
// Legacy field: accounts.extra.openai_oauth_passthrough (historical OAuth switch).
// The first of these that holds a bool decides; when neither does, the result
// is false (disabled).
func (a *Account) IsOpenAIPassthroughEnabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	for _, key := range [...]string{"openai_passthrough", "openai_oauth_passthrough"} {
		if enabled, isBool := a.Extra[key].(bool); isBool {
			return enabled
		}
	}
	return false
}
// IsOpenAIResponsesWebSocketV2Enabled reports whether an OpenAI account has
// Responses WebSocket v2 enabled.
//
// Per-account-type fields:
//   - OAuth accounts: accounts.extra.openai_oauth_responses_websockets_v2_enabled
//   - API-key accounts: accounts.extra.openai_apikey_responses_websockets_v2_enabled
//
// Legacy fields:
//   - accounts.extra.responses_websockets_v2_enabled
//   - accounts.extra.openai_ws_enabled (historical switch)
//
// Priority:
//  1. the field matching the account type
//  2. the legacy fields, when the typed field is missing or not a bool
func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	// readFlag returns the bool stored under key and whether it was a bool at all.
	readFlag := func(key string) (enabled, present bool) {
		enabled, present = a.Extra[key].(bool)
		return enabled, present
	}
	if a.IsOpenAIOAuth() {
		if enabled, present := readFlag("openai_oauth_responses_websockets_v2_enabled"); present {
			return enabled
		}
	}
	if a.IsOpenAIApiKey() {
		if enabled, present := readFlag("openai_apikey_responses_websockets_v2_enabled"); present {
			return enabled
		}
	}
	for _, key := range [...]string{"responses_websockets_v2_enabled", "openai_ws_enabled"} {
		if enabled, present := readFlag(key); present {
			return enabled
		}
	}
	return false
}
// OpenAI WebSocket ingress mode identifiers. These are the values recognized
// by normalizeOpenAIWSIngressMode.
const (
	OpenAIWSIngressModeOff         = "off"
	OpenAIWSIngressModeShared      = "shared"
	OpenAIWSIngressModeDedicated   = "dedicated"
	OpenAIWSIngressModeCtxPool     = "ctx_pool"
	OpenAIWSIngressModePassthrough = "passthrough"
)
// normalizeOpenAIWSIngressMode lower-cases and trims mode and returns it when
// it is one of the known ingress modes; unknown values yield "".
func normalizeOpenAIWSIngressMode(mode string) string {
	candidate := strings.ToLower(strings.TrimSpace(mode))
	switch candidate {
	case OpenAIWSIngressModeOff,
		OpenAIWSIngressModeCtxPool,
		OpenAIWSIngressModePassthrough,
		OpenAIWSIngressModeShared,
		OpenAIWSIngressModeDedicated:
		return candidate
	}
	return ""
}
// normalizeOpenAIWSIngressDefaultMode normalizes a default ingress mode.
// Unknown values as well as the legacy "shared" and "dedicated" modes resolve
// to ctx_pool; "off", "ctx_pool" and "passthrough" are returned as-is.
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
	switch normalized := normalizeOpenAIWSIngressMode(mode); normalized {
	case "", OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
		return OpenAIWSIngressModeCtxPool
	default:
		return normalized
	}
}
// ResolveOpenAIResponsesWebSocketV2Mode returns the effective WSv2 ingress
// mode of the account.
//
// Priority:
//  1. the per-account-type mode field (string), when it normalizes to a known mode
//  2. the per-account-type enabled field (bool): true -> ctx_pool, false -> off
//  3. the legacy enabled fields (bool): responses_websockets_v2_enabled, then
//     openai_ws_enabled
//  4. defaultMode, normalized (invalid and legacy shared/dedicated values
//     resolve to ctx_pool)
//
// A nil or non-OpenAI account always resolves to off.
//
// NOTE(review): a per-account mode of "shared"/"dedicated" is returned
// unchanged; only the default is folded into ctx_pool — confirm callers expect
// that.
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
	if a == nil || !a.IsOpenAI() {
		return OpenAIWSIngressModeOff
	}
	// normalizeOpenAIWSIngressDefaultMode already maps shared/dedicated and
	// invalid values to ctx_pool, so the default needs no further remapping.
	resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
	if a.Extra == nil {
		return resolvedDefault
	}

	// modeFromString reads a string mode field; ok is false when the field is
	// missing, not a string, or not a known mode.
	modeFromString := func(key string) (mode string, ok bool) {
		raw, isString := a.Extra[key].(string)
		if !isString {
			return "", false
		}
		normalized := normalizeOpenAIWSIngressMode(raw)
		return normalized, normalized != ""
	}
	// modeFromBool reads a bool switch; ok is false when the field is missing
	// or not a bool.
	modeFromBool := func(key string) (mode string, ok bool) {
		enabled, isBool := a.Extra[key].(bool)
		if !isBool {
			return "", false
		}
		if enabled {
			return OpenAIWSIngressModeCtxPool, true
		}
		return OpenAIWSIngressModeOff, true
	}

	if a.IsOpenAIOAuth() {
		if mode, ok := modeFromString("openai_oauth_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := modeFromBool("openai_oauth_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	if a.IsOpenAIApiKey() {
		if mode, ok := modeFromString("openai_apikey_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := modeFromBool("openai_apikey_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	if mode, ok := modeFromBool("responses_websockets_v2_enabled"); ok {
		return mode
	}
	if mode, ok := modeFromBool("openai_ws_enabled"); ok {
		return mode
	}
	return resolvedDefault
}
// IsOpenAIWSForceHTTPEnabled reports the account-level "force HTTP" switch.
// Field: accounts.extra.openai_ws_force_http. Only OpenAI accounts can enable it.
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	if forced, isBool := a.Extra["openai_ws_force_http"].(bool); isBool {
		return forced
	}
	return false
}
// IsOpenAIWSAllowStoreRecoveryEnabled reports the account-level store recovery
// switch. Field: accounts.extra.openai_ws_allow_store_recovery. Only OpenAI
// accounts can enable it.
func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	if allowed, isBool := a.Extra["openai_ws_allow_store_recovery"].(bool); isBool {
		return allowed
	}
	return false
}
// IsOpenAIOAuthPassthroughEnabled is kept for backward compatibility; it is
// IsOpenAIPassthroughEnabled restricted to OpenAI OAuth accounts.
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
	if a == nil || !a.IsOpenAIOAuth() {
		return false
	}
	return a.IsOpenAIPassthroughEnabled()
}
// IsAnthropicAPIKeyPassthroughEnabled reports whether an Anthropic API-key
// account has "automatic passthrough (only the authentication is replaced)"
// enabled. Field: accounts.extra.anthropic_passthrough. A missing field or a
// non-bool value counts as false (disabled).
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
	if a == nil || a.Extra == nil {
		return false
	}
	if a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey {
		return false
	}
	if enabled, isBool := a.Extra["anthropic_passthrough"].(bool); isBool {
		return enabled
	}
	return false
}
// IsCodexCLIOnlyEnabled reports whether an OpenAI OAuth account is restricted
// to "official Codex client only". Field: accounts.extra.codex_cli_only.
// A missing field or a non-bool value counts as false (disabled).
func (a *Account) IsCodexCLIOnlyEnabled() bool {
	if a == nil || !a.IsOpenAIOAuth() || a.Extra == nil {
		return false
	}
	if restricted, isBool := a.Extra["codex_cli_only"].(bool); isBool {
		return restricted
	}
	return false
}
// WindowCostSchedulability is the scheduling state derived from window cost
// (and reused by the RPM check in CheckRPMSchedulability).
type WindowCostSchedulability int

const (
	// WindowCostSchedulable means the account can be scheduled normally.
	WindowCostSchedulable WindowCostSchedulability = iota
	// WindowCostStickyOnly means only sticky sessions are allowed.
	WindowCostStickyOnly
	// WindowCostNotSchedulable means the account cannot be scheduled at all.
	WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken reports whether this is an Anthropic account of
// type OAuth or SetupToken. Only these account types support the 5h window
// quota control and the session count control.
func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
	if a.Platform != PlatformAnthropic {
		return false
	}
	return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
// IsTLSFingerprintEnabled reports whether TLS fingerprint emulation is enabled
// (Extra "enable_tls_fingerprint" set to the bool true).
// It only applies to Anthropic OAuth/SetupToken accounts. When enabled, the
// TLS handshake characteristics of the Claude Code (Node.js) client are
// emulated.
func (a *Account) IsTLSFingerprintEnabled() bool {
	// Only Anthropic OAuth/SetupToken accounts are eligible.
	if !a.IsAnthropicOAuthOrSetupToken() {
		return false
	}
	enabled, isBool := a.Extra["enable_tls_fingerprint"].(bool)
	return isBool && enabled
}
// GetUserMsgQueueMode returns the user message queue mode of the account:
// "serialize" = serial queue, "throttle" = soft rate limiting, "" = not set
// (the global configuration applies).
func (a *Account) GetUserMsgQueueMode() string {
	if a.Extra == nil {
		return ""
	}
	// The new field user_msg_queue_mode takes precedence and is validated
	// against the known modes; an unknown value counts as "not set".
	if mode, _ := a.Extra["user_msg_queue_mode"].(string); mode != "" {
		if mode != config.UMQModeSerialize && mode != config.UMQModeThrottle {
			return "" // invalid value: fall back to the global configuration
		}
		return mode
	}
	// Backward compatibility: user_msg_queue_enabled: true -> "serialize".
	if legacy, _ := a.Extra["user_msg_queue_enabled"].(bool); legacy {
		return config.UMQModeSerialize
	}
	return ""
}
// IsSessionIDMaskingEnabled reports whether session ID masking is enabled
// (Extra "session_id_masking_enabled" set to the bool true).
// It only applies to Anthropic OAuth/SetupToken accounts. When enabled, the
// session ID inside metadata.user_id is kept fixed for a period (15 minutes)
// so that the upstream sees the requests as one session.
func (a *Account) IsSessionIDMaskingEnabled() bool {
	if !a.IsAnthropicOAuthOrSetupToken() {
		return false
	}
	enabled, isBool := a.Extra["session_id_masking_enabled"].(bool)
	return isBool && enabled
}
// IsCacheTTLOverrideEnabled reports whether the cache TTL override is enabled
// (Extra "cache_ttl_override_enabled" set to the bool true).
// It only applies to Anthropic OAuth/SetupToken accounts. When enabled, all
// cache creation tokens are attributed to the configured TTL type (5m or 1h).
func (a *Account) IsCacheTTLOverrideEnabled() bool {
	if !a.IsAnthropicOAuthOrSetupToken() {
		return false
	}
	enabled, isBool := a.Extra["cache_ttl_override_enabled"].(bool)
	return isBool && enabled
}
// GetCacheTTLOverrideTarget returns the target TTL type of the cache TTL
// override: "5m" or "1h". Anything else, including a missing field, yields the
// default "5m".
func (a *Account) GetCacheTTLOverrideTarget() string {
	target, _ := a.Extra["cache_ttl_override_target"].(string)
	if target == "1h" {
		return "1h"
	}
	return "5m"
}
// GetQuotaLimit returns the total quota limit (USD) of an API-key account,
// read from Extra "quota_limit". 0 means the limit is not enabled.
func (a *Account) GetQuotaLimit() float64 {
	limit := a.getExtraFloat64("quota_limit")
	return limit
}
// GetQuotaUsed returns the quota already used (USD) by an API-key account,
// read from Extra "quota_used".
func (a *Account) GetQuotaUsed() float64 {
	used := a.getExtraFloat64("quota_used")
	return used
}
// GetQuotaDailyLimit returns the daily quota limit (USD), read from Extra
// "quota_daily_limit". 0 means the limit is not enabled.
func (a *Account) GetQuotaDailyLimit() float64 {
	limit := a.getExtraFloat64("quota_daily_limit")
	return limit
}
// GetQuotaDailyUsed returns the quota used today (USD), read from Extra
// "quota_daily_used".
func (a *Account) GetQuotaDailyUsed() float64 {
	used := a.getExtraFloat64("quota_daily_used")
	return used
}
// GetQuotaWeeklyLimit returns the weekly quota limit (USD), read from Extra
// "quota_weekly_limit". 0 means the limit is not enabled.
func (a *Account) GetQuotaWeeklyLimit() float64 {
	limit := a.getExtraFloat64("quota_weekly_limit")
	return limit
}
// GetQuotaWeeklyUsed returns the quota used this week (USD), read from Extra
// "quota_weekly_used".
func (a *Account) GetQuotaWeeklyUsed() float64 {
	used := a.getExtraFloat64("quota_weekly_used")
	return used
}
// getExtraFloat64 reads the value stored under key in Extra and converts it
// with parseExtraFloat64. A missing key (or nil Extra) yields 0.
func (a *Account) getExtraFloat64(key string) float64 {
	// Looking up a nil map simply reports "not found".
	raw, found := a.Extra[key]
	if !found {
		return 0
	}
	return parseExtraFloat64(raw)
}
// getExtraTime reads an RFC3339 timestamp string stored under key in Extra.
// It returns the zero time when the key is missing, the value is not a string,
// or it does not parse as RFC3339Nano / RFC3339.
func (a *Account) getExtraTime(key string) time.Time {
	text, isString := a.Extra[key].(string)
	if !isString {
		return time.Time{}
	}
	for _, layout := range [...]string{time.RFC3339Nano, time.RFC3339} {
		if parsed, err := time.Parse(layout, text); err == nil {
			return parsed
		}
	}
	return time.Time{}
}
// HasAnyQuotaLimit reports whether a quota limit is configured in at least one
// dimension (total, daily or weekly).
func (a *Account) HasAnyQuotaLimit() bool {
	if a.GetQuotaLimit() > 0 {
		return true
	}
	if a.GetQuotaDailyLimit() > 0 {
		return true
	}
	return a.GetQuotaWeeklyLimit() > 0
}
// isPeriodExpired reports whether the period that started at periodStart and
// lasts dur has elapsed. A zero periodStart (never used) counts as expired;
// the next increment is expected to initialize it.
func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
	return periodStart.IsZero() || time.Since(periodStart) >= dur
}
// IsQuotaExceeded reports whether the API-key account has exceeded its quota
// in any dimension (total, daily or weekly).
//
// For the daily and weekly dimensions an expired period counts as "not
// exceeded"; the next increment is expected to reset it.
func (a *Account) IsQuotaExceeded() bool {
	// Total quota.
	if total := a.GetQuotaLimit(); total > 0 && a.GetQuotaUsed() >= total {
		return true
	}
	// windowExceeded evaluates one periodic dimension; used is only consulted
	// while the period is still running.
	windowExceeded := func(limit float64, startKey string, window time.Duration, used func() float64) bool {
		if !(limit > 0) {
			return false
		}
		if isPeriodExpired(a.getExtraTime(startKey), window) {
			return false
		}
		return used() >= limit
	}
	// Daily quota.
	if windowExceeded(a.GetQuotaDailyLimit(), "quota_daily_start", 24*time.Hour, a.GetQuotaDailyUsed) {
		return true
	}
	// Weekly quota.
	return windowExceeded(a.GetQuotaWeeklyLimit(), "quota_weekly_start", 7*24*time.Hour, a.GetQuotaWeeklyUsed)
}
// GetWindowCostLimit returns the 5h window cost threshold (USD), read from
// Extra "window_cost_limit". 0 means the limit is not enabled.
func (a *Account) GetWindowCostLimit() float64 {
	// Looking up a nil map simply reports "not found".
	raw, found := a.Extra["window_cost_limit"]
	if !found {
		return 0
	}
	return parseExtraFloat64(raw)
}
// GetWindowCostStickyReserve returns the cost reserve (USD) kept for sticky
// sessions, read from Extra "window_cost_sticky_reserve". Missing or
// non-positive values yield the default of 10.
func (a *Account) GetWindowCostStickyReserve() float64 {
	const defaultReserve = 10.0
	raw, found := a.Extra["window_cost_sticky_reserve"]
	if !found {
		return defaultReserve
	}
	if reserve := parseExtraFloat64(raw); reserve > 0 {
		return reserve
	}
	return defaultReserve
}
// GetMaxSessions returns the maximum number of concurrent sessions, read from
// Extra "max_sessions". 0 means the limit is not enabled.
func (a *Account) GetMaxSessions() int {
	raw, found := a.Extra["max_sessions"]
	if !found {
		return 0
	}
	return parseExtraInt(raw)
}
// GetSessionIdleTimeoutMinutes returns the session idle timeout in minutes,
// read from Extra "session_idle_timeout_minutes". Missing or non-positive
// values yield the default of 5 minutes.
func (a *Account) GetSessionIdleTimeoutMinutes() int {
	const defaultMinutes = 5
	raw, found := a.Extra["session_idle_timeout_minutes"]
	if !found {
		return defaultMinutes
	}
	if minutes := parseExtraInt(raw); minutes > 0 {
		return minutes
	}
	return defaultMinutes
}
// GetBaseRPM returns the base RPM limit, read from Extra "base_rpm".
// 0 means the limit is not enabled; negative values are invalid configuration
// and are treated as 0.
func (a *Account) GetBaseRPM() int {
	raw, found := a.Extra["base_rpm"]
	if !found {
		return 0
	}
	if rpm := parseExtraInt(raw); rpm > 0 {
		return rpm
	}
	return 0
}
// GetRPMStrategy returns the RPM strategy, read from Extra "rpm_strategy":
// "tiered" = three-zone model (default), "sticky_exempt" = sticky exemption.
// Any value other than "sticky_exempt" yields "tiered".
func (a *Account) GetRPMStrategy() string {
	strategy, _ := a.Extra["rpm_strategy"].(string)
	if strategy == "sticky_exempt" {
		return "sticky_exempt"
	}
	return "tiered"
}
// GetRPMStickyBuffer returns the RPM sticky buffer, i.e. the size of the
// yellow zone in tiered mode. A positive Extra "rpm_sticky_buffer" is used
// as-is; otherwise it defaults to 20% of the base RPM (at least 1 when a base
// RPM is configured). A nil Extra yields 0.
func (a *Account) GetRPMStickyBuffer() int {
	if a.Extra == nil {
		return 0
	}
	if raw, found := a.Extra["rpm_sticky_buffer"]; found {
		if configured := parseExtraInt(raw); configured > 0 {
			return configured
		}
	}
	baseRPM := a.GetBaseRPM()
	derived := baseRPM / 5
	if baseRPM > 0 && derived < 1 {
		derived = 1
	}
	return derived
}
// CheckRPMSchedulability derives the scheduling state from the current RPM
// count. It reuses the three WindowCostSchedulability states
// (Schedulable / StickyOnly / NotSchedulable).
//
// Below the base RPM (or without a base RPM) the account is schedulable. At or
// above it, the "sticky_exempt" strategy allows sticky sessions only; the
// tiered strategy allows sticky sessions inside the buffer and nothing beyond.
func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability {
	baseRPM := a.GetBaseRPM()
	if baseRPM <= 0 || currentRPM < baseRPM {
		return WindowCostSchedulable
	}
	// Sticky exemption has no red zone.
	if a.GetRPMStrategy() == "sticky_exempt" {
		return WindowCostStickyOnly
	}
	// tiered: yellow zone first, then red zone.
	if yellowZoneEnd := baseRPM + a.GetRPMStickyBuffer(); currentRPM < yellowZoneEnd {
		return WindowCostStickyOnly
	}
	return WindowCostNotSchedulable
}
// CheckWindowCostSchedulability derives the scheduling state from the current
// window cost:
//   - cost < limit: WindowCostSchedulable (normal scheduling)
//   - limit <= cost < limit+reserve: WindowCostStickyOnly (sticky sessions only)
//   - cost >= limit+reserve: WindowCostNotSchedulable (no scheduling)
//
// Without a positive limit the account is always schedulable.
func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
	limit := a.GetWindowCostLimit()
	switch {
	case limit <= 0, currentWindowCost < limit:
		return WindowCostSchedulable
	case currentWindowCost < limit+a.GetWindowCostStickyReserve():
		return WindowCostStickyOnly
	default:
		return WindowCostNotSchedulable
	}
}
// GetCurrentWindowStartTime returns the start time of the currently effective
// window:
//  1. while the window has not ended (SessionWindowEnd is set and lies in the
//     future), the recorded SessionWindowStart is used
//  2. otherwise (window expired or unset) a new window start is predicted: the
//     top of the current hour
func (a *Account) GetCurrentWindowStartTime() time.Time {
	now := time.Now()
	windowStart, windowEnd := a.SessionWindowStart, a.SessionWindowEnd
	if windowStart != nil && windowEnd != nil && now.Before(*windowEnd) {
		return *windowStart
	}
	// Per the original note, this prediction is meant to stay consistent with
	// UpdateSessionWindow in ratelimit_service.go.
	year, month, day := now.Date()
	return time.Date(year, month, day, now.Hour(), 0, 0, 0, now.Location())
}
// parseExtraFloat64 converts a value from the extra field to float64. It
// accepts float64, float32, int, int64, json.Number and numeric strings
// (trimmed); anything else, or a parse failure, yields 0.
func parseExtraFloat64(value any) float64 {
	if f, isFloat64 := value.(float64); isFloat64 {
		return f
	}
	if f, isFloat32 := value.(float32); isFloat32 {
		return float64(f)
	}
	if n, isInt := value.(int); isInt {
		return float64(n)
	}
	if n, isInt64 := value.(int64); isInt64 {
		return float64(n)
	}
	if num, isNumber := value.(json.Number); isNumber {
		parsed, err := num.Float64()
		if err != nil {
			return 0
		}
		return parsed
	}
	if text, isString := value.(string); isString {
		parsed, err := strconv.ParseFloat(strings.TrimSpace(text), 64)
		if err != nil {
			return 0
		}
		return parsed
	}
	return 0
}
// ParseExtraInt converts an any value from the extra field to int.
// It supports int, int64, float64, json.Number and string; values that cannot
// be parsed yield 0.
func ParseExtraInt(value any) int {
	parsed := parseExtraInt(value)
	return parsed
}
// parseExtraInt converts a value from the extra field to int. It accepts int,
// int64, float64 (truncated), json.Number and numeric strings (trimmed);
// anything else, or a parse failure, yields 0.
func parseExtraInt(value any) int {
	if n, isInt := value.(int); isInt {
		return n
	}
	if n, isInt64 := value.(int64); isInt64 {
		return int(n)
	}
	if f, isFloat := value.(float64); isFloat {
		return int(f)
	}
	if num, isNumber := value.(json.Number); isNumber {
		parsed, err := num.Int64()
		if err != nil {
			return 0
		}
		return int(parsed)
	}
	if text, isString := value.(string); isString {
		parsed, err := strconv.Atoi(strings.TrimSpace(text))
		if err != nil {
			return 0
		}
		return parsed
	}
	return 0
}