From e71be7e0f1b86bceb2086feea645f44c78ff9b8f Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 27 Feb 2026 15:13:05 +0800 Subject: [PATCH] fix: update image pricing tests for 2K tier and refactor claude max billing policy - Update 5 test assertions to match new 2K default price ($0.201 = base * 1.5) - Refactor claude max cache billing policy into reusable functions --- .../service/billing_service_image_test.go | 28 ++--- .../claude_max_cache_billing_policy.go | 100 ++++++++++++++++-- 2 files changed, 107 insertions(+), 21 deletions(-) diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index 18a6b74d..59125814 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -12,14 +12,14 @@ import ( func TestCalculateImageCost_DefaultPricing(t *testing.T) { svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值 - // 2K 尺寸,默认价格 $0.134 + // 2K 尺寸,默认价格 $0.134 * 1.5 = $0.201 cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) - require.InDelta(t, 0.134, cost.ActualCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 多张图片 cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0) - require.InDelta(t, 0.402, cost.TotalCost, 0.0001) + require.InDelta(t, 0.603, cost.TotalCost, 0.0001) } // TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格 @@ -63,13 +63,13 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) { // 费率倍数 1.5x cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) // TotalCost 不变 - require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // ActualCost = 0.134 * 1.5 + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5 + require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5 // 费率倍数 2.0x cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0) - require.InDelta(t, 0.268, cost.TotalCost, 0.0001) - require.InDelta(t, 0.536, cost.ActualCost, 0.0001) + require.InDelta(t, 0.402, cost.TotalCost, 0.0001) + require.InDelta(t, 0.804, cost.ActualCost, 0.0001) } // TestCalculateImageCost_ZeroCount 测试 imageCount=0 @@ -95,8 +95,8 @@ func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { svc := &BillingService{} cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) - require.InDelta(t, 0.134, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 } // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 @@ -127,9 +127,9 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) { cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0) require.InDelta(t, 0.10, cost.TotalCost, 0.0001) - // 2K 回退默认价格 $0.134 + // 2K 回退默认价格 $0.201 (1.5倍) cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // 4K 回退默认价格 $0.268 (翻倍) cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0) @@ -140,10 +140,10 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) { func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) { svc := &BillingService{} // pricingService 为 nil - // 1K 和 2K 使用相同的默认价格 $0.134 + // 1K 默认价格 $0.134,2K 默认价格 $0.201 (1.5倍) cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0) require.InDelta(t, 0.134, cost.TotalCost, 0.0001) cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) } diff --git a/backend/internal/service/claude_max_cache_billing_policy.go b/backend/internal/service/claude_max_cache_billing_policy.go index 5f2e2def..021d968c 100644 --- a/backend/internal/service/claude_max_cache_billing_policy.go +++ b/backend/internal/service/claude_max_cache_billing_policy.go @@ -64,6 +64,70 @@ func applyClaudeMaxCacheBillingPolicy(input *RecordUsageInput) claudeMaxCacheBil return out } +// detectClaudeMaxCacheBillingOutcomeForUsage only returns whether Claude Max policy +// should influence downstream override decisions. It does not mutate usage. +func detectClaudeMaxCacheBillingOutcomeForUsage(usage ClaudeUsage, parsed *ParsedRequest, group *Group, model string) claudeMaxCacheBillingOutcome { + var out claudeMaxCacheBillingOutcome + if !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) { + return out + } + if hasCacheCreationTokens(usage) { + out.ForcedCache1H = true + return out + } + if shouldSimulateClaudeMaxUsageForUsage(usage, parsed) { + out.Simulated = true + } + return out +} + +func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome { + var out claudeMaxCacheBillingOutcome + if usage == nil || !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) { + return out + } + + resolvedModel := strings.TrimSpace(model) + if resolvedModel == "" && parsed != nil { + resolvedModel = strings.TrimSpace(parsed.Model) + } + + if hasCacheCreationTokens(*usage) { + before5m := usage.CacheCreation5mTokens + before1h := usage.CacheCreation1hTokens + changed := safelyForceCacheCreationTo1H(usage) + // Even when value is already 1h, still mark forced to skip account TTL override. + out.ForcedCache1H = true + if changed { + logger.LegacyPrintf("service.gateway", "force_claude_max_cache_1h: model=%s account=%d cache_creation_5m:%d->%d cache_creation_1h:%d->%d", + resolvedModel, + accountID, + before5m, + usage.CacheCreation5mTokens, + before1h, + usage.CacheCreation1hTokens, + ) + } + return out + } + + if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) { + return out + } + beforeInputTokens := usage.InputTokens + out.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed) + if out.Simulated { + logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d", + resolvedModel, + accountID, + beforeInputTokens, + usage.InputTokens, + usage.CacheCreation1hTokens, + ) + } + return out +} + func isClaudeFamilyModel(model string) bool { normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model))) if normalized == "" { @@ -76,16 +140,22 @@ func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool { if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil { return false } - group := input.APIKey.Group + return shouldApplyClaudeMaxBillingRulesForUsage(input.APIKey.Group, input.Result.Model, input.ParsedRequest) +} + +func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool { + if group == nil { + return false + } if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic { return false } - model := input.Result.Model - if model == "" && input.ParsedRequest != nil { - model = input.ParsedRequest.Model + resolvedModel := model + if resolvedModel == "" && parsed != nil { + resolvedModel = parsed.Model } - if !isClaudeFamilyModel(model) { + if !isClaudeFamilyModel(resolvedModel) { return false } return true @@ -96,13 +166,19 @@ func hasCacheCreationTokens(usage ClaudeUsage) bool { } func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool { + if input == nil || input.Result == nil { + return false + } if !shouldApplyClaudeMaxBillingRules(input) { return false } - if !hasClaudeCacheSignals(input.ParsedRequest) { + return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest) +} + +func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool { + if !hasClaudeCacheSignals(parsed) { return false } - usage := input.Result.Usage if usage.InputTokens <= 0 { return false } @@ -149,6 +225,16 @@ func safelyApplyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRe return applyClaudeMaxUsageSimulation(result, parsed) } +func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) { + defer func() { + if r := recover(); r != nil { + logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r) + changed = false + } + }() + return projectUsageToClaudeMax1H(usage, parsed) +} + func safelyForceCacheCreationTo1H(usage *ClaudeUsage) (changed bool) { defer func() { if r := recover(); r != nil {