diff --git a/backend/go.mod b/backend/go.mod index ec3cf509..70b675fa 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -59,6 +59,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect @@ -109,6 +110,8 @@ require ( github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pkoukk/tiktoken-go v0.1.8 // indirect + github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/quic-go/qpack v0.6.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 9e9fc545..6b4c2f7c 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -64,6 +64,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM= github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= @@ -223,6 +225,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo= +github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4= +github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/backend/internal/service/claude_max_cache_billing_policy.go b/backend/internal/service/claude_max_cache_billing_policy.go new file mode 100644 index 00000000..5f2e2def --- /dev/null +++ b/backend/internal/service/claude_max_cache_billing_policy.go @@ -0,0 +1,500 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" +) + +type claudeMaxCacheBillingOutcome struct { + Simulated bool + ForcedCache1H bool +} + +func applyClaudeMaxCacheBillingPolicy(input *RecordUsageInput) claudeMaxCacheBillingOutcome { + var out claudeMaxCacheBillingOutcome + if !shouldApplyClaudeMaxBillingRules(input) { + return out + } + + if input == nil || input.Result == nil { + return out + } + result := input.Result + usage := &result.Usage + accountID := int64(0) + if input.Account != nil { + accountID = input.Account.ID + } + + if hasCacheCreationTokens(*usage) { + before5m := usage.CacheCreation5mTokens + before1h := usage.CacheCreation1hTokens + out.ForcedCache1H = safelyForceCacheCreationTo1H(usage) + if out.ForcedCache1H { + logger.LegacyPrintf("service.gateway", "force_claude_max_cache_1h: model=%s account=%d cache_creation_5m:%d->%d cache_creation_1h:%d->%d", + result.Model, + accountID, + before5m, + usage.CacheCreation5mTokens, + before1h, + usage.CacheCreation1hTokens, + ) + } + return out + } + + if !shouldSimulateClaudeMaxUsage(input) { + return out + } + beforeInputTokens := usage.InputTokens + out.Simulated = safelyApplyClaudeMaxUsageSimulation(result, input.ParsedRequest) + if out.Simulated { + logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d", + result.Model, + accountID, + beforeInputTokens, + usage.InputTokens, + usage.CacheCreation1hTokens, + ) + } + return out +} + +func isClaudeFamilyModel(model string) bool { + normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model))) + if normalized == "" { + return false + } + return strings.Contains(normalized, "claude-") +} + +func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool { + if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil { + return false + } + group := input.APIKey.Group + if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic { + return false + } + + model := input.Result.Model + if model == "" && input.ParsedRequest != nil { + model = input.ParsedRequest.Model + } + if !isClaudeFamilyModel(model) { + return false + } + return true +} + +func hasCacheCreationTokens(usage ClaudeUsage) bool { + return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0 +} + +func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool { + if !shouldApplyClaudeMaxBillingRules(input) { + return false + } + if !hasClaudeCacheSignals(input.ParsedRequest) { + return false + } + usage := input.Result.Usage + if usage.InputTokens <= 0 { + return false + } + if hasCacheCreationTokens(usage) { + return false + } + return true +} + +func forceCacheCreationTo1H(usage *ClaudeUsage) bool { + if usage == nil || !hasCacheCreationTokens(*usage) { + return false + } + + before5m := usage.CacheCreation5mTokens + before1h := usage.CacheCreation1hTokens + beforeAgg := usage.CacheCreationInputTokens + + _ = applyCacheTTLOverride(usage, "1h") + total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if total <= 0 { + total = usage.CacheCreationInputTokens + } + if total <= 0 { + return false + } + + usage.CacheCreation5mTokens = 0 + usage.CacheCreation1hTokens = total + usage.CacheCreationInputTokens = total + + return before5m != usage.CacheCreation5mTokens || + before1h != usage.CacheCreation1hTokens || + beforeAgg != usage.CacheCreationInputTokens +} + +func safelyApplyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) (changed bool) { + defer func() { + if r := recover(); r != nil { + logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r) + changed = false + } + }() + return applyClaudeMaxUsageSimulation(result, parsed) +} + +func safelyForceCacheCreationTo1H(usage *ClaudeUsage) (changed bool) { + defer func() { + if r := recover(); r != nil { + logger.LegacyPrintf("service.gateway", "force_cache_creation_1h skipped: panic=%v", r) + changed = false + } + }() + return forceCacheCreationTo1H(usage) +} + +func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool { + if result == nil { + return false + } + return projectUsageToClaudeMax1H(&result.Usage, parsed) +} + +func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool { + if usage == nil { + return false + } + totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if totalWindowTokens <= 1 { + return false + } + + simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed) + if simulatedInputTokens <= 0 { + simulatedInputTokens = 1 + } + if simulatedInputTokens >= totalWindowTokens { + simulatedInputTokens = totalWindowTokens - 1 + } + + cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens + if usage.InputTokens == simulatedInputTokens && + usage.CacheCreation5mTokens == 0 && + usage.CacheCreation1hTokens == cacheCreation1hTokens && + usage.CacheCreationInputTokens == cacheCreation1hTokens { + return false + } + + usage.InputTokens = simulatedInputTokens + usage.CacheCreation5mTokens = 0 + usage.CacheCreation1hTokens = cacheCreation1hTokens + usage.CacheCreationInputTokens = cacheCreation1hTokens + return true +} + +type claudeCacheProjection struct { + HasBreakpoint bool + BreakpointCount int + TotalEstimatedTokens int + TailEstimatedTokens int +} + +func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int { + if totalWindowTokens <= 1 { + return totalWindowTokens + } + + projection := analyzeClaudeCacheProjection(parsed) + if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 { + return totalWindowTokens + } + + totalEstimate := int64(projection.TotalEstimatedTokens) + tailEstimate := int64(projection.TailEstimatedTokens) + if tailEstimate > totalEstimate { + tailEstimate = totalEstimate + } + + scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate + if scaled <= 0 { + scaled = 1 + } + if scaled >= int64(totalWindowTokens) { + scaled = int64(totalWindowTokens - 1) + } + return int(scaled) +} + +func hasClaudeCacheSignals(parsed *ParsedRequest) bool { + if parsed == nil { + return false + } + if hasTopLevelEphemeralCacheControl(parsed) { + return true + } + return countExplicitCacheBreakpoints(parsed) > 0 +} + +func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool { + if parsed == nil || len(parsed.Body) == 0 { + return false + } + cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String()) + return strings.EqualFold(cacheType, "ephemeral") +} + +func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection { + var projection claudeCacheProjection + if parsed == nil { + return projection + } + + total := 0 + lastBreakpointAt := -1 + + switch system := parsed.System.(type) { + case string: + total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system) + case []any: + for _, raw := range system { + block, ok := raw.(map[string]any) + if !ok { + total += claudeMaxUnknownContentTokens + continue + } + total += estimateClaudeBlockTokens(block) + if hasEphemeralCacheControl(block) { + lastBreakpointAt = total + projection.BreakpointCount++ + projection.HasBreakpoint = true + } + } + } + + for _, rawMsg := range parsed.Messages { + total += claudeMaxMessageOverheadTokens + msg, ok := rawMsg.(map[string]any) + if !ok { + total += claudeMaxUnknownContentTokens + continue + } + content, exists := msg["content"] + if !exists { + continue + } + msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content) + total += msgTokens + if msgBreakCount > 0 { + lastBreakpointAt = total - msgTokens + msgLastBreak + projection.BreakpointCount += msgBreakCount + projection.HasBreakpoint = true + } + } + + if total <= 0 { + total = 1 + } + projection.TotalEstimatedTokens = total + + if projection.HasBreakpoint && lastBreakpointAt >= 0 { + tail := total - lastBreakpointAt + if tail <= 0 { + tail = 1 + } + projection.TailEstimatedTokens = tail + return projection + } + + if hasTopLevelEphemeralCacheControl(parsed) { + tail := estimateLastUserMessageTokens(parsed) + if tail <= 0 { + tail = 1 + } + projection.HasBreakpoint = true + projection.BreakpointCount = 1 + projection.TailEstimatedTokens = tail + } + return projection +} + +func countExplicitCacheBreakpoints(parsed *ParsedRequest) int { + if parsed == nil { + return 0 + } + total := 0 + if system, ok := parsed.System.([]any); ok { + for _, raw := range system { + if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) { + total++ + } + } + } + for _, rawMsg := range parsed.Messages { + msg, ok := rawMsg.(map[string]any) + if !ok { + continue + } + content, ok := msg["content"].([]any) + if !ok { + continue + } + for _, raw := range content { + if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) { + total++ + } + } + } + return total +} + +func hasEphemeralCacheControl(block map[string]any) bool { + if block == nil { + return false + } + raw, ok := block["cache_control"] + if !ok || raw == nil { + return false + } + switch cc := raw.(type) { + case map[string]any: + cacheType, _ := cc["type"].(string) + return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral") + case map[string]string: + return strings.EqualFold(strings.TrimSpace(cc["type"]), "ephemeral") + default: + return false + } +} + +func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) { + switch value := content.(type) { + case string: + return estimateClaudeTextTokens(value), -1, 0 + case []any: + total := 0 + lastBreak := -1 + breaks := 0 + for _, raw := range value { + block, ok := raw.(map[string]any) + if !ok { + total += claudeMaxUnknownContentTokens + continue + } + total += estimateClaudeBlockTokens(block) + if hasEphemeralCacheControl(block) { + lastBreak = total + breaks++ + } + } + return total, lastBreak, breaks + default: + return estimateStructuredTokens(value), -1, 0 + } +} + +func estimateClaudeBlockTokens(block map[string]any) int { + if block == nil { + return claudeMaxUnknownContentTokens + } + tokens := claudeMaxBlockOverheadTokens + blockType, _ := block["type"].(string) + switch blockType { + case "text": + if text, ok := block["text"].(string); ok { + tokens += estimateClaudeTextTokens(text) + } + case "tool_result": + if content, ok := block["content"]; ok { + nested, _, _ := estimateClaudeContentTokens(content) + tokens += nested + } + case "tool_use": + if name, ok := block["name"].(string); ok { + tokens += estimateClaudeTextTokens(name) + } + if input, ok := block["input"]; ok { + tokens += estimateStructuredTokens(input) + } + default: + if text, ok := block["text"].(string); ok { + tokens += estimateClaudeTextTokens(text) + } else if content, ok := block["content"]; ok { + nested, _, _ := estimateClaudeContentTokens(content) + tokens += nested + } + } + if tokens <= claudeMaxBlockOverheadTokens { + tokens += claudeMaxUnknownContentTokens + } + return tokens +} + +func estimateLastUserMessageTokens(parsed *ParsedRequest) int { + if parsed == nil || len(parsed.Messages) == 0 { + return 0 + } + for i := len(parsed.Messages) - 1; i >= 0; i-- { + msg, ok := parsed.Messages[i].(map[string]any) + if !ok { + continue + } + role, _ := msg["role"].(string) + if !strings.EqualFold(strings.TrimSpace(role), "user") { + continue + } + tokens, _, _ := estimateClaudeContentTokens(msg["content"]) + return claudeMaxMessageOverheadTokens + tokens + } + return 0 +} + +func estimateStructuredTokens(v any) int { + if v == nil { + return 0 + } + raw, err := json.Marshal(v) + if err != nil { + return claudeMaxUnknownContentTokens + } + return estimateClaudeTextTokens(string(raw)) +} + +func estimateClaudeTextTokens(text string) int { + if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok { + return tokens + } + return estimateClaudeTextTokensHeuristic(text) +} + +func estimateClaudeTextTokensHeuristic(text string) int { + normalized := strings.Join(strings.Fields(strings.TrimSpace(text)), " ") + if normalized == "" { + return 0 + } + asciiChars := 0 + nonASCIIChars := 0 + for _, r := range normalized { + if r <= 127 { + asciiChars++ + } else { + nonASCIIChars++ + } + } + tokens := nonASCIIChars + if asciiChars > 0 { + tokens += (asciiChars + 3) / 4 + } + if words := len(strings.Fields(normalized)); words > tokens { + tokens = words + } + if tokens <= 0 { + return 1 + } + return tokens +} diff --git a/backend/internal/service/claude_max_simulation_test.go b/backend/internal/service/claude_max_simulation_test.go index 8f4690a0..3d2ae2e6 100644 --- a/backend/internal/service/claude_max_simulation_test.go +++ b/backend/internal/service/claude_max_simulation_test.go @@ -1,6 +1,9 @@ package service -import "testing" +import ( + "strings" + "testing" +) func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) { usage := &ClaudeUsage{ @@ -13,8 +16,18 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) { Model: "claude-sonnet-4-5", Messages: []any{ map[string]any{ - "role": "user", - "content": "请帮我总结这段代码并给出优化建议", + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": strings.Repeat("cached context ", 200), + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "summarize quickly", + }, + }, }, }, } @@ -34,6 +47,9 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) { if usage.InputTokens <= 0 || usage.InputTokens >= 1200 { t.Fatalf("simulated input out of range, got=%d", usage.InputTokens) } + if usage.InputTokens > 100 { + t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens) + } if usage.CacheCreation1hTokens <= 0 { t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens) } @@ -42,22 +58,29 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) { } } -func TestComputeClaudeMaxSimulatedInputTokens_Deterministic(t *testing.T) { +func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) { parsed := &ParsedRequest{ Model: "claude-opus-4-5", Messages: []any{ map[string]any{ "role": "user", "content": []any{ - map[string]any{"type": "text", "text": "请整理以下日志并定位错误根因"}, - map[string]any{"type": "tool_use", "name": "grep_logs"}, + map[string]any{ + "type": "text", + "text": "build context", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "what is failing now", + }, }, }, }, } - got1 := computeClaudeMaxSimulatedInputTokens(4096, parsed) - got2 := computeClaudeMaxSimulatedInputTokens(4096, parsed) + got1 := computeClaudeMaxProjectedInputTokens(4096, parsed) + got2 := computeClaudeMaxProjectedInputTokens(4096, parsed) if got1 != got2 { t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2) } @@ -78,13 +101,54 @@ func TestShouldSimulateClaudeMaxUsage(t *testing.T) { CacheCreation1hTokens: 0, }, }, + ParsedRequest: &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "cached", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "tail", + }, + }, + }, + }, + }, APIKey: &APIKey{Group: group}, } if !shouldSimulateClaudeMaxUsage(input) { - t.Fatalf("expected simulate=true for claude group without cache creation") + t.Fatalf("expected simulate=true for claude group with cache signal") } + input.ParsedRequest = &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "no cache signal"}, + }, + } + if shouldSimulateClaudeMaxUsage(input) { + t.Fatalf("expected simulate=false when request has no cache signal") + } + + input.ParsedRequest = &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "cached", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + }, + }, + } input.Result.Usage.CacheCreationInputTokens = 100 if shouldSimulateClaudeMaxUsage(input) { t.Fatalf("expected simulate=false when cache creation already exists") diff --git a/backend/internal/service/claude_tokenizer.go b/backend/internal/service/claude_tokenizer.go new file mode 100644 index 00000000..61f5e961 --- /dev/null +++ b/backend/internal/service/claude_tokenizer.go @@ -0,0 +1,41 @@ +package service + +import ( + "sync" + + tiktoken "github.com/pkoukk/tiktoken-go" + tiktokenloader "github.com/pkoukk/tiktoken-go-loader" +) + +var ( + claudeTokenizerOnce sync.Once + claudeTokenizer *tiktoken.Tiktoken +) + +func getClaudeTokenizer() *tiktoken.Tiktoken { + claudeTokenizerOnce.Do(func() { + // Use offline loader to avoid runtime dictionary download. + tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader()) + // Use a high-capacity tokenizer as the default approximation for Claude payloads. + enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE) + if err != nil { + enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE) + } + if err == nil { + claudeTokenizer = enc + } + }) + return claudeTokenizer +} + +func estimateTokensByThirdPartyTokenizer(text string) (int, bool) { + enc := getClaudeTokenizer() + if enc == nil { + return 0, false + } + tokens := len(enc.EncodeOrdinary(text)) + if tokens <= 0 { + return 0, false + } + return tokens, true +} diff --git a/backend/internal/service/gateway_record_usage_claude_max_test.go b/backend/internal/service/gateway_record_usage_claude_max_test.go index a4c8850c..445519f8 100644 --- a/backend/internal/service/gateway_record_usage_claude_max_test.go +++ b/backend/internal/service/gateway_record_usage_claude_max_test.go @@ -50,8 +50,18 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsAndSkipsTTLOverride(t *tes Model: "claude-sonnet-4", Messages: []any{ map[string]any{ - "role": "user", - "content": "please summarize the logs and provide root cause analysis", + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "long cached context for prior turns", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "please summarize the logs and provide root cause analysis", + }, + }, }, }, }, @@ -138,3 +148,53 @@ func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) require.Equal(t, 0, log.CacheCreation1hTokens) require.True(t, log.CacheTTLOverridden, "TTL override 生效时应打标") } + +func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationForce1H(t *testing.T) { + repo := &usageLogRepoRecordUsageStub{inserted: true} + svc := newGatewayServiceForRecordUsageTest(repo) + + groupID := int64(13) + input := &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "req-sim-3", + Model: "claude-sonnet-4", + Duration: time.Second, + Usage: ClaudeUsage{ + InputTokens: 20, + CacheCreationInputTokens: 120, + CacheCreation5mTokens: 120, + }, + }, + APIKey: &APIKey{ + ID: 3, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + Platform: PlatformAnthropic, + RateMultiplier: 1, + SimulateClaudeMaxEnabled: true, + }, + }, + User: &User{ID: 4}, + Account: &Account{ + ID: 5, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "5m", + }, + }, + } + + err := svc.RecordUsage(context.Background(), input) + require.NoError(t, err) + require.NotNil(t, repo.last) + + log := repo.last + require.Equal(t, 20, log.InputTokens, "existing cache creation should not project input tokens") + require.Equal(t, 0, log.CacheCreation5mTokens, "existing cache creation should be forced to 1h") + require.Equal(t, 120, log.CacheCreation1hTokens) + require.Equal(t, 120, log.CacheCreationTokens) + require.True(t, log.CacheTTLOverridden, "force-to-1h should mark cache ttl overridden") +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 71d69561..53b1fd28 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -57,12 +57,9 @@ const ( ) const ( - claudeMaxSimInputMinTokens = 8 - claudeMaxSimInputMaxTokens = 96 - claudeMaxSimBaseOverheadTokens = 8 - claudeMaxSimPerBlockOverhead = 2 - claudeMaxSimSummaryMaxRunes = 160 - claudeMaxSimContextDivisor = 16 + claudeMaxMessageOverheadTokens = 3 + claudeMaxBlockOverheadTokens = 1 + claudeMaxUnknownContentTokens = 4 ) // ForceCacheBillingContextKey 强制缓存计费上下文键 @@ -5575,224 +5572,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, return multiplier } -func isClaudeFamilyModel(model string) bool { - normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model))) - if normalized == "" { - return false - } - return strings.Contains(normalized, "claude-") -} - -func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool { - if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil { - return false - } - group := input.APIKey.Group - if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic { - return false - } - - model := input.Result.Model - if model == "" && input.ParsedRequest != nil { - model = input.ParsedRequest.Model - } - if !isClaudeFamilyModel(model) { - return false - } - - usage := input.Result.Usage - if usage.InputTokens <= 0 { - return false - } - if usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0 { - return false - } - return true -} - -func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool { - if result == nil { - return false - } - return projectUsageToClaudeMax1H(&result.Usage, parsed) -} - -func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool { - if usage == nil { - return false - } - totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens - if totalWindowTokens <= 1 { - return false - } - - simulatedInputTokens := computeClaudeMaxSimulatedInputTokens(totalWindowTokens, parsed) - if simulatedInputTokens <= 0 { - simulatedInputTokens = 1 - } - if simulatedInputTokens >= totalWindowTokens { - simulatedInputTokens = totalWindowTokens - 1 - } - - cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens - if usage.InputTokens == simulatedInputTokens && - usage.CacheCreation5mTokens == 0 && - usage.CacheCreation1hTokens == cacheCreation1hTokens && - usage.CacheCreationInputTokens == cacheCreation1hTokens { - return false - } - - usage.InputTokens = simulatedInputTokens - usage.CacheCreation5mTokens = 0 - usage.CacheCreation1hTokens = cacheCreation1hTokens - usage.CacheCreationInputTokens = cacheCreation1hTokens - return true -} - -func computeClaudeMaxSimulatedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int { - if totalWindowTokens <= 1 { - return totalWindowTokens - } - - summary, blockCount := extractTailUserMessageSummary(parsed) - if blockCount <= 0 { - blockCount = 1 - } - - asciiChars := 0 - nonASCIIChars := 0 - for _, r := range summary { - if r <= 127 { - asciiChars++ - continue - } - nonASCIIChars++ - } - - lexicalTokens := nonASCIIChars - if asciiChars > 0 { - lexicalTokens += (asciiChars + 3) / 4 - } - wordCount := len(strings.Fields(summary)) - if wordCount > lexicalTokens { - lexicalTokens = wordCount - } - if lexicalTokens == 0 { - lexicalTokens = 1 - } - - structuralTokens := claudeMaxSimBaseOverheadTokens + blockCount*claudeMaxSimPerBlockOverhead - rawInputTokens := structuralTokens + lexicalTokens - - maxInputTokens := clampInt(totalWindowTokens/claudeMaxSimContextDivisor, claudeMaxSimInputMinTokens, claudeMaxSimInputMaxTokens) - if totalWindowTokens <= claudeMaxSimInputMinTokens+1 { - maxInputTokens = totalWindowTokens - 1 - } - if maxInputTokens <= 0 { - return totalWindowTokens - } - - minInputTokens := 1 - if totalWindowTokens > claudeMaxSimInputMinTokens+1 { - minInputTokens = claudeMaxSimInputMinTokens - } - return clampInt(rawInputTokens, minInputTokens, maxInputTokens) -} - -func extractTailUserMessageSummary(parsed *ParsedRequest) (string, int) { - if parsed == nil || len(parsed.Messages) == 0 { - return "", 1 - } - for i := len(parsed.Messages) - 1; i >= 0; i-- { - message, ok := parsed.Messages[i].(map[string]any) - if !ok { - continue - } - role, _ := message["role"].(string) - if !strings.EqualFold(strings.TrimSpace(role), "user") { - continue - } - summary, blockCount := summarizeUserContentBlocks(message["content"]) - if blockCount <= 0 { - blockCount = 1 - } - return summary, blockCount - } - return "", 1 -} - -func summarizeUserContentBlocks(content any) (string, int) { - appendSegment := func(segments []string, raw string) []string { - normalized := strings.Join(strings.Fields(strings.TrimSpace(raw)), " ") - if normalized == "" { - return segments - } - return append(segments, normalized) - } - - switch value := content.(type) { - case string: - return trimClaudeMaxSummary(value), 1 - case []any: - if len(value) == 0 { - return "", 1 - } - segments := make([]string, 0, len(value)) - for _, blockRaw := range value { - block, ok := blockRaw.(map[string]any) - if !ok { - continue - } - blockType, _ := block["type"].(string) - switch blockType { - case "text": - if text, ok := block["text"].(string); ok { - segments = appendSegment(segments, text) - } - case "tool_result": - nestedSummary, _ := summarizeUserContentBlocks(block["content"]) - segments = appendSegment(segments, nestedSummary) - case "tool_use": - if name, ok := block["name"].(string); ok { - segments = appendSegment(segments, name) - } - default: - if text, ok := block["text"].(string); ok { - segments = appendSegment(segments, text) - } - } - } - return trimClaudeMaxSummary(strings.Join(segments, " ")), len(value) - default: - return "", 1 - } -} - -func trimClaudeMaxSummary(summary string) string { - normalized := strings.Join(strings.Fields(strings.TrimSpace(summary)), " ") - if normalized == "" { - return "" - } - runes := []rune(normalized) - if len(runes) > claudeMaxSimSummaryMaxRunes { - return string(runes[:claudeMaxSimSummaryMaxRunes]) - } - return normalized -} - -func clampInt(v, minValue, maxValue int) int { - if minValue > maxValue { - return minValue - } - if v < minValue { - return minValue - } - if v > maxValue { - return maxValue - } - return v -} - // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult @@ -5829,25 +5608,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu result.Usage.InputTokens = 0 } - // Claude 分组模拟:将无写缓存 usage 映射为 claude-max 风格的 1h cache creation。 - simulatedClaudeMax := false - if shouldSimulateClaudeMaxUsage(input) { - beforeInputTokens := result.Usage.InputTokens - simulatedClaudeMax = applyClaudeMaxUsageSimulation(result, input.ParsedRequest) - if simulatedClaudeMax { - logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d", - result.Model, - account.ID, - beforeInputTokens, - result.Usage.InputTokens, - result.Usage.CacheCreation1hTokens, - ) - } - } + // Claude Max cache billing policy (group-level): force existing cache creation to 1h, + // otherwise simulate projection only when request carries cache signals. + claudeMaxOutcome := applyClaudeMaxCacheBillingPolicy(input) + simulatedClaudeMax := claudeMaxOutcome.Simulated + forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H // Cache TTL Override: 确保计费时 token 分类与账号设置一致 - cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { + cacheTTLOverridden := forcedClaudeMax1H + if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax && !forcedClaudeMax1H { applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 }