diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 2ace9f61..49e2b412 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7706,8 +7706,109 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage } } +// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 +type recordUsageOpts struct { + // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) + ParsedRequest *ParsedRequest + + // EnableClaudePath 启用 Claude 路径特有逻辑: + // - Claude Max 缓存计费策略 + // - Sora 媒体类型分支(image/video/prompt) + // - MediaType 字段写入使用日志 + EnableClaudePath bool + + // 长上下文计费(仅 Gemini 路径需要) + LongContextThreshold int + LongContextMultiplier float64 +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { + return s.recordUsageCore(ctx, &recordUsageCoreInput{ + Result: input.Result, + APIKey: input.APIKey, + User: input.User, + Account: input.Account, + Subscription: input.Subscription, + InboundEndpoint: input.InboundEndpoint, + UpstreamEndpoint: input.UpstreamEndpoint, + UserAgent: input.UserAgent, + IPAddress: input.IPAddress, + RequestPayloadHash: input.RequestPayloadHash, + ForceCacheBilling: input.ForceCacheBilling, + APIKeyService: input.APIKeyService, + ChannelUsageFields: input.ChannelUsageFields, + }, &recordUsageOpts{ + ParsedRequest: input.ParsedRequest, + EnableClaudePath: true, + }) +} + +// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) +type RecordUsageLongContextInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + InboundEndpoint string // 入站端点(客户端请求路径) + UpstreamEndpoint string // 上游端点(标准化后的上游路径) + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) + + ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) +} + +// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) +func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { + return s.recordUsageCore(ctx, &recordUsageCoreInput{ + Result: input.Result, + APIKey: input.APIKey, + User: input.User, + Account: input.Account, + Subscription: input.Subscription, + InboundEndpoint: input.InboundEndpoint, + UpstreamEndpoint: input.UpstreamEndpoint, + UserAgent: input.UserAgent, + IPAddress: input.IPAddress, + RequestPayloadHash: input.RequestPayloadHash, + ForceCacheBilling: input.ForceCacheBilling, + APIKeyService: input.APIKeyService, + ChannelUsageFields: input.ChannelUsageFields, + }, &recordUsageOpts{ + LongContextThreshold: input.LongContextThreshold, + LongContextMultiplier: input.LongContextMultiplier, + }) +} + +// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。 +type recordUsageCoreInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + InboundEndpoint string + UpstreamEndpoint string + UserAgent string + IPAddress string + RequestPayloadHash string + ForceCacheBilling bool + APIKeyService APIKeyQuotaUpdater + ChannelUsageFields +} + +// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 +// opts 中的字段控制两者之间的差异行为: +// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 +// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt) +// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext +func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result apiKey := input.APIKey user := input.User @@ -7723,9 +7824,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu result.Usage.InputTokens = 0 } - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + // Claude Max cache billing policy(仅 Claude 路径启用) cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { + simulatedClaudeMax := false + if opts.EnableClaudePath { + var apiKeyGroup *Group + if apiKey != nil { + apiKeyGroup = apiKey.Group + } + claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID) + simulatedClaudeMax = claudeMaxOutcome.Simulated || + (shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage)) + } + + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 } @@ -7740,7 +7853,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } - var cost *CostBreakdown // 确定计费模型 billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" { @@ -7756,100 +7868,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu requestedModel = input.OriginalModel } - // 根据请求类型选择计费方式 - if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo { - var soraConfig *SoraPriceConfig - if apiKey.Group != nil { - soraConfig = &SoraPriceConfig{ - ImagePrice360: apiKey.Group.SoraImagePrice360, - ImagePrice540: apiKey.Group.SoraImagePrice540, - VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - } - } - if result.MediaType == MediaTypeImage { - cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) - } else { - cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) - } - } else if result.MediaType == MediaTypePrompt { - cost = &CostBreakdown{} - } else if result.ImageCount > 0 { - // 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本) - hasChannelPricing := false - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) - if resolved.Source == PricingSourceChannel { - hasChannelPricing = true - } - } - if hasChannelPricing { - // 渠道定价优先 → 由 CalculateCostUnified 按 resolved.Mode 分发计费 - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - } - gid := apiKey.Group.ID - var err error - cost, err = s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - Resolver: s.resolver, - }) - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) - cost = &CostBreakdown{ActualCost: 0} - } - } else { - // 无渠道定价 → 走按次计费(默认,兼容旧版本) - var groupConfig *ImagePriceConfig - if apiKey.Group != nil { - groupConfig = &ImagePriceConfig{ - Price1K: apiKey.Group.ImagePrice1K, - Price2K: apiKey.Group.ImagePrice2K, - Price4K: apiKey.Group.ImagePrice4K, - } - } - cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) - } - } else { - // Token 计费 - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - } - var err error - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - groupID := &gid - cost, err = s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: groupID, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - Resolver: s.resolver, - }) - } else { - cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) - } - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) - cost = &CostBreakdown{ActualCost: 0} - } - } + // 计算费用 + cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts) // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() @@ -7859,13 +7879,222 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } // 创建使用日志 + usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, + requestedModel, multiplier, billingType, cacheTTLOverridden, cost, opts) + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + requestID := usageLog.RequestID + accountRateMultiplier := account.BillingRateMultiplier() + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr + } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + + return nil +} + +// calculateRecordUsageCost 根据请求类型和选项计算费用。 +func (s *GatewayService) calculateRecordUsageCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + opts *recordUsageOpts, +) *CostBreakdown { + // Sora 媒体类型分支(仅 Claude 路径启用) + if opts.EnableClaudePath { + if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo { + return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier) + } + if result.MediaType == MediaTypePrompt { + return &CostBreakdown{} + } + } + + // 图片生成计费 + if result.ImageCount > 0 { + return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) + } + + // Token 计费 + return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) +} + +// calculateSoraMediaCost 计算 Sora 图片/视频的费用。 +func (s *GatewayService) calculateSoraMediaCost( + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, +) *CostBreakdown { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == MediaTypeImage { + return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } + return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) +} + +// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。 +func (s *GatewayService) calculateImageCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, +) *CostBreakdown { + hasChannelPricing := false + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + hasChannelPricing = true + } + } + if hasChannelPricing { + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + } + gid := apiKey.Group.ID + cost, err := s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + }) + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) + return &CostBreakdown{ActualCost: 0} + } + return cost + } + + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) +} + +// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。 +func (s *GatewayService) calculateTokenCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + opts *recordUsageOpts, +) *CostBreakdown { + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + } + + var cost *CostBreakdown + var err error + + // 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用) + useUnified := false + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + }) + useUnified = true + } + } + if !useUnified { + if opts.LongContextThreshold > 0 { + // 长上下文双倍计费(如 Gemini 200K 阈值) + cost, err = s.billingService.CalculateCostWithLongContext( + billingModel, tokens, multiplier, + opts.LongContextThreshold, opts.LongContextMultiplier, + ) + } else { + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) + } + } + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) + return &CostBreakdown{ActualCost: 0} + } + return cost +} + +// buildRecordUsageLog 构建使用日志并设置计费模式。 +func (s *GatewayService) buildRecordUsageLog( + ctx context.Context, + input *recordUsageCoreInput, + result *ForwardResult, + apiKey *APIKey, + user *User, + account *Account, + subscription *UserSubscription, + requestedModel string, + multiplier float64, + billingType int8, + cacheTTLOverridden bool, + cost *CostBreakdown, + opts *recordUsageOpts, +) *UsageLog { durationMs := int(result.Duration.Milliseconds()) var imageSize *string if result.ImageSize != "" { imageSize = &result.ImageSize } var mediaType *string - if strings.TrimSpace(result.MediaType) != "" { + if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { mediaType = &result.MediaType } accountRateMultiplier := account.BillingRateMultiplier() @@ -7912,8 +8141,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.ActualCost = cost.ActualCost } - // 设置计费模式 - if result.MediaType != MediaTypeImage && result.MediaType != MediaTypeVideo && result.MediaType != MediaTypePrompt { + // 设置计费模式:Sora 媒体类型自身已确定计费模式(由上游处理),跳过 + isSoraMedia := opts.EnableClaudePath && + (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) + if !isSoraMedia { if cost != nil && cost.BillingMode != "" { billingMode := cost.BillingMode usageLog.BillingMode = &billingMode @@ -7944,307 +8175,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.SubscriptionID = &subscription.ID } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil - } - - billingErr := func() error { - _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ - Cost: cost, - User: user, - APIKey: apiKey, - Account: account, - Subscription: subscription, - RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), - IsSubscriptionBill: isSubscriptionBilling, - AccountRateMultiplier: accountRateMultiplier, - APIKeyService: input.APIKeyService, - }, s.billingDeps(), s.usageBillingRepo) - return err - }() - - if billingErr != nil { - return billingErr - } - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - - return nil -} - -// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) -type RecordUsageLongContextInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - InboundEndpoint string // 入站端点(客户端请求路径) - UpstreamEndpoint string // 上游端点(标准化后的上游路径) - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 - LongContextThreshold int // 长上下文阈值(如 200000) - LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) - - ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) -} - -// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) -func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { - result := input.Result - apiKey := input.APIKey - user := input.User - account := input.Account - subscription := input.Subscription - - // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens - // 用于粘性会话切换时的特殊计费处理 - if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", - result.Usage.InputTokens, account.ID) - result.Usage.CacheReadInputTokens += result.Usage.InputTokens - result.Usage.InputTokens = 0 - } - - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 - cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { - applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) - cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 - } - - // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := 1.0 - if s.cfg != nil { - multiplier = s.cfg.Default.RateMultiplier - } - if apiKey.GroupID != nil && apiKey.Group != nil { - groupDefault := apiKey.Group.RateMultiplier - multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) - } - - var cost *CostBreakdown - // 确定计费模型 - billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) - if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" { - billingModel = input.ChannelMappedModel - } - if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { - billingModel = input.OriginalModel - } - - // 确定 RequestedModel(渠道映射前的原始模型) - requestedModel := result.Model - if input.OriginalModel != "" { - requestedModel = input.OriginalModel - } - - // 根据请求类型选择计费方式 - if result.ImageCount > 0 { - // 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本) - hasChannelPricing := false - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) - if resolved.Source == PricingSourceChannel { - hasChannelPricing = true - } - } - if hasChannelPricing { - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - } - gid := apiKey.Group.ID - var err error - cost, err = s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - Resolver: s.resolver, - }) - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) - cost = &CostBreakdown{ActualCost: 0} - } - } else { - var groupConfig *ImagePriceConfig - if apiKey.Group != nil { - groupConfig = &ImagePriceConfig{ - Price1K: apiKey.Group.ImagePrice1K, - Price2K: apiKey.Group.ImagePrice2K, - Price4K: apiKey.Group.ImagePrice4K, - } - } - cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) - } - } else { - // Token 计费(使用长上下文计费方法) - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - } - var err error - // 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用) - useUnified := false - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - resolved := s.resolver.Resolve(ctx, PricingInput{ - Model: billingModel, - GroupID: &gid, - }) - if resolved.Source == PricingSourceChannel { - // 有渠道定价,渠道区间已包含上下文分层 - cost, err = s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - Resolver: s.resolver, - }) - useUnified = true - } - } - if !useUnified { - // 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值) - cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) - } - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) - cost = &CostBreakdown{ActualCost: 0} - } - } - - // 判断计费方式:订阅模式 vs 余额模式 - isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() - billingType := BillingTypeBalance - if isSubscriptionBilling { - billingType = BillingTypeSubscription - } - - // 创建使用日志 - durationMs := int(result.Duration.Milliseconds()) - var imageSize *string - if result.ImageSize != "" { - imageSize = &result.ImageSize - } - accountRateMultiplier := account.BillingRateMultiplier() - requestID := resolveUsageBillingRequestID(ctx, result.RequestID) - usageLog := &UsageLog{ - UserID: user.ID, - APIKeyID: apiKey.ID, - AccountID: account.ID, - RequestID: requestID, - Model: result.Model, - RequestedModel: requestedModel, - UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), - ReasoningEffort: result.ReasoningEffort, - InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), - UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - RateMultiplier: multiplier, - AccountRateMultiplier: &accountRateMultiplier, - BillingType: billingType, - Stream: result.Stream, - DurationMs: &durationMs, - FirstTokenMs: result.FirstTokenMs, - ImageCount: result.ImageCount, - ImageSize: imageSize, - CacheTTLOverridden: cacheTTLOverridden, - ChannelID: optionalInt64Ptr(input.ChannelID), - ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), - CreatedAt: time.Now(), - } - if cost != nil { - usageLog.InputCost = cost.InputCost - usageLog.OutputCost = cost.OutputCost - usageLog.ImageOutputCost = cost.ImageOutputCost - usageLog.CacheCreationCost = cost.CacheCreationCost - usageLog.CacheReadCost = cost.CacheReadCost - usageLog.TotalCost = cost.TotalCost - usageLog.ActualCost = cost.ActualCost - } - - // 设置计费模式 - if cost != nil && cost.BillingMode != "" { - billingMode := cost.BillingMode - usageLog.BillingMode = &billingMode - } else if result.ImageCount > 0 { - billingMode := string(BillingModeImage) - usageLog.BillingMode = &billingMode - } else { - billingMode := string(BillingModeToken) - usageLog.BillingMode = &billingMode - } - - // 添加 UserAgent - if input.UserAgent != "" { - usageLog.UserAgent = &input.UserAgent - } - - // 添加 IPAddress - if input.IPAddress != "" { - usageLog.IPAddress = &input.IPAddress - } - - // 添加分组和订阅关联 - if apiKey.GroupID != nil { - usageLog.GroupID = apiKey.GroupID - } - if subscription != nil { - usageLog.SubscriptionID = &subscription.ID - } - - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil - } - - billingErr := func() error { - _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ - Cost: cost, - User: user, - APIKey: apiKey, - Account: account, - Subscription: subscription, - RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), - IsSubscriptionBill: isSubscriptionBilling, - AccountRateMultiplier: accountRateMultiplier, - APIKeyService: input.APIKeyService, - }, s.billingDeps(), s.usageBillingRepo) - return err - }() - - if billingErr != nil { - return billingErr - } - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - - return nil + return usageLog } // ResolveChannelMapping 委托渠道服务解析模型映射