refactor: decouple claude max cache policy and add tokenizer

This commit is contained in:
erio
2026-02-27 12:18:22 +08:00
parent 886464b2e9
commit 6da2f54e50
7 changed files with 695 additions and 252 deletions

View File

@@ -59,6 +59,7 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
@@ -109,6 +110,8 @@ require (
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.6.0 // indirect

View File

@@ -64,6 +64,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
@@ -223,6 +225,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=

View File

@@ -0,0 +1,500 @@
package service
import (
"encoding/json"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
)
// claudeMaxCacheBillingOutcome reports which Claude Max cache billing rule,
// if any, mutated the usage for a request.
type claudeMaxCacheBillingOutcome struct {
	// Simulated is true when input tokens were projected into a simulated
	// 1h cache-creation split (no cache creation reported upstream).
	Simulated bool
	// ForcedCache1H is true when upstream-reported cache-creation tokens
	// were all reclassified into the 1h bucket.
	ForcedCache1H bool
}
// applyClaudeMaxCacheBillingPolicy applies the group-level Claude Max cache
// billing rules to input.Result.Usage, mutating it in place, and reports
// which rule (if any) fired.
//
// The two outcomes are mutually exclusive:
//   - ForcedCache1H: upstream already reported cache-creation tokens, which
//     are all reclassified into the 1h bucket.
//   - Simulated: no cache-creation tokens were reported but the request
//     carries cache signals, so input tokens are projected into a 1h
//     cache-creation split.
//
// The zero outcome is returned when the group has not opted in, the model is
// not a Claude-family model, or no rule applies.
func applyClaudeMaxCacheBillingPolicy(input *RecordUsageInput) claudeMaxCacheBillingOutcome {
	var out claudeMaxCacheBillingOutcome
	if !shouldApplyClaudeMaxBillingRules(input) {
		return out
	}
	// Defensive: shouldApplyClaudeMaxBillingRules already rejects nil inputs.
	if input == nil || input.Result == nil {
		return out
	}
	result := input.Result
	usage := &result.Usage
	// accountID 0 stands for "unknown account" in the log lines below.
	accountID := int64(0)
	if input.Account != nil {
		accountID = input.Account.ID
	}
	if hasCacheCreationTokens(*usage) {
		// Snapshot the buckets so the log line can show the before/after split.
		before5m := usage.CacheCreation5mTokens
		before1h := usage.CacheCreation1hTokens
		out.ForcedCache1H = safelyForceCacheCreationTo1H(usage)
		if out.ForcedCache1H {
			logger.LegacyPrintf("service.gateway", "force_claude_max_cache_1h: model=%s account=%d cache_creation_5m:%d->%d cache_creation_1h:%d->%d",
				result.Model,
				accountID,
				before5m,
				usage.CacheCreation5mTokens,
				before1h,
				usage.CacheCreation1hTokens,
			)
		}
		return out
	}
	if !shouldSimulateClaudeMaxUsage(input) {
		return out
	}
	beforeInputTokens := usage.InputTokens
	out.Simulated = safelyApplyClaudeMaxUsageSimulation(result, input.ParsedRequest)
	if out.Simulated {
		logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
			result.Model,
			accountID,
			beforeInputTokens,
			usage.InputTokens,
			usage.CacheCreation1hTokens,
		)
	}
	return out
}
// isClaudeFamilyModel reports whether the normalized model identifier belongs
// to the Claude model family (contains a "claude-" segment).
func isClaudeFamilyModel(model string) bool {
	id := strings.TrimSpace(claude.NormalizeModelID(model))
	if id == "" {
		return false
	}
	return strings.Contains(strings.ToLower(id), "claude-")
}
// shouldApplyClaudeMaxBillingRules reports whether the Claude Max billing
// policy is applicable: the request has a result, its API key group opted in
// (SimulateClaudeMaxEnabled) on the Anthropic platform, and the model is a
// Claude-family model (falling back to the parsed request's model when the
// result carries none).
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
	if input == nil || input.Result == nil || input.APIKey == nil {
		return false
	}
	group := input.APIKey.Group
	if group == nil || group.Platform != PlatformAnthropic || !group.SimulateClaudeMaxEnabled {
		return false
	}
	model := input.Result.Model
	if model == "" {
		if req := input.ParsedRequest; req != nil {
			model = req.Model
		}
	}
	return isClaudeFamilyModel(model)
}
// hasCacheCreationTokens reports whether the usage carries any cache-creation
// tokens in the aggregate or in either TTL bucket.
func hasCacheCreationTokens(usage ClaudeUsage) bool {
	switch {
	case usage.CacheCreationInputTokens > 0:
		return true
	case usage.CacheCreation5mTokens > 0:
		return true
	case usage.CacheCreation1hTokens > 0:
		return true
	default:
		return false
	}
}
// shouldSimulateClaudeMaxUsage reports whether the usage-projection rule
// applies: the general billing rules hold, the request carries cache signals,
// there are input tokens to project, and upstream reported no cache creation.
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
	if !shouldApplyClaudeMaxBillingRules(input) {
		return false
	}
	if !hasClaudeCacheSignals(input.ParsedRequest) {
		return false
	}
	usage := input.Result.Usage
	return usage.InputTokens > 0 && !hasCacheCreationTokens(usage)
}
// forceCacheCreationTo1H moves every cache-creation token on usage into the
// 1h bucket, zeroing the 5m bucket and keeping the aggregate field equal to
// the new 1h total. Returns true when any of the three cache-creation fields
// actually changed.
func forceCacheCreationTo1H(usage *ClaudeUsage) bool {
	if usage == nil || !hasCacheCreationTokens(*usage) {
		return false
	}
	before5m := usage.CacheCreation5mTokens
	before1h := usage.CacheCreation1hTokens
	beforeAgg := usage.CacheCreationInputTokens
	// NOTE(review): the override's result is discarded and the 5m->1h merge is
	// redone manually below — presumably applyCacheTTLOverride normalizes the
	// bucket split first; confirm this call is still required.
	_ = applyCacheTTLOverride(usage, "1h")
	total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
	if total <= 0 {
		// Fall back to the aggregate when both split buckets are empty.
		total = usage.CacheCreationInputTokens
	}
	if total <= 0 {
		return false
	}
	usage.CacheCreation5mTokens = 0
	usage.CacheCreation1hTokens = total
	usage.CacheCreationInputTokens = total
	return before5m != usage.CacheCreation5mTokens ||
		before1h != usage.CacheCreation1hTokens ||
		beforeAgg != usage.CacheCreationInputTokens
}
// safelyApplyClaudeMaxUsageSimulation runs applyClaudeMaxUsageSimulation and
// converts any panic into a logged no-op so billing never crashes the request.
func safelyApplyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) (changed bool) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
		changed = false
	}()
	changed = applyClaudeMaxUsageSimulation(result, parsed)
	return changed
}
// safelyForceCacheCreationTo1H runs forceCacheCreationTo1H and converts any
// panic into a logged no-op so billing never crashes the request.
func safelyForceCacheCreationTo1H(usage *ClaudeUsage) (changed bool) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		logger.LegacyPrintf("service.gateway", "force_cache_creation_1h skipped: panic=%v", r)
		changed = false
	}()
	changed = forceCacheCreationTo1H(usage)
	return changed
}
// applyClaudeMaxUsageSimulation projects the result's usage into the Claude
// Max 1h-cache shape; it is a nil-safe wrapper over projectUsageToClaudeMax1H.
func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool {
	if result == nil {
		return false
	}
	usage := &result.Usage
	return projectUsageToClaudeMax1H(usage, parsed)
}
// projectUsageToClaudeMax1H redistributes the request's token window between
// input tokens and 1h cache-creation tokens, conserving the window total.
// The projected input share is estimated from the request structure; the
// remainder becomes 1h cache creation. Returns true when usage was mutated.
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
	if usage == nil {
		return false
	}
	window := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
	if window <= 1 {
		// Too small to split between input and cache.
		return false
	}
	projected := computeClaudeMaxProjectedInputTokens(window, parsed)
	// Keep both sides of the split strictly positive.
	if projected <= 0 {
		projected = 1
	}
	if projected >= window {
		projected = window - 1
	}
	cached := window - projected
	alreadyProjected := usage.InputTokens == projected &&
		usage.CacheCreation5mTokens == 0 &&
		usage.CacheCreation1hTokens == cached &&
		usage.CacheCreationInputTokens == cached
	if alreadyProjected {
		return false
	}
	usage.InputTokens = projected
	usage.CacheCreation5mTokens = 0
	usage.CacheCreation1hTokens = cached
	usage.CacheCreationInputTokens = cached
	return true
}
// claudeCacheProjection summarizes the estimated token layout of a request
// with respect to explicit prompt-cache breakpoints.
type claudeCacheProjection struct {
	// HasBreakpoint is true when at least one cache signal was found.
	HasBreakpoint bool
	// BreakpointCount is the number of cache breakpoints detected.
	BreakpointCount int
	// TotalEstimatedTokens is the rough token estimate for the whole request.
	TotalEstimatedTokens int
	// TailEstimatedTokens estimates the tokens after the last breakpoint.
	TailEstimatedTokens int
}
// computeClaudeMaxProjectedInputTokens estimates how many of the window's
// tokens should be billed as plain input: the window is scaled by the ratio
// of the estimated tail (after the last cache breakpoint) to the estimated
// total, with half-up rounding. Without a usable breakpoint projection the
// whole window is returned unchanged. The result is clamped to
// [1, totalWindowTokens-1] so both sides of the split stay positive.
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
	if totalWindowTokens <= 1 {
		return totalWindowTokens
	}
	proj := analyzeClaudeCacheProjection(parsed)
	if !proj.HasBreakpoint || proj.TotalEstimatedTokens <= 0 || proj.TailEstimatedTokens <= 0 {
		return totalWindowTokens
	}
	total := int64(proj.TotalEstimatedTokens)
	tail := int64(proj.TailEstimatedTokens)
	if tail > total {
		tail = total
	}
	// Half-up rounded scaling: window * tail / total.
	scaled := (int64(totalWindowTokens)*tail + total/2) / total
	if scaled < 1 {
		scaled = 1
	}
	if limit := int64(totalWindowTokens - 1); scaled > limit {
		scaled = limit
	}
	return int(scaled)
}
// hasClaudeCacheSignals reports whether the request opts into prompt caching,
// either via a top-level ephemeral cache_control or via at least one explicit
// block-level cache breakpoint.
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
	if parsed == nil {
		return false
	}
	return hasTopLevelEphemeralCacheControl(parsed) || countExplicitCacheBreakpoints(parsed) > 0
}
// hasTopLevelEphemeralCacheControl reports whether the raw request body has a
// top-level cache_control of type "ephemeral" (case-insensitive).
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
	if parsed == nil || len(parsed.Body) == 0 {
		return false
	}
	value := gjson.GetBytes(parsed.Body, "cache_control.type").String()
	return strings.EqualFold(strings.TrimSpace(value), "ephemeral")
}
// analyzeClaudeCacheProjection walks the system prompt and message list of a
// parsed request and produces a rough token estimate for the whole prompt
// (TotalEstimatedTokens) plus the portion after the last cache breakpoint
// (TailEstimatedTokens). Estimates come from the local estimateClaude*Tokens
// heuristics and are approximate by design.
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
	var projection claudeCacheProjection
	if parsed == nil {
		return projection
	}
	total := 0
	// Running total at the most recent breakpoint; -1 means none seen yet.
	lastBreakpointAt := -1
	switch system := parsed.System.(type) {
	case string:
		total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
	case []any:
		for _, raw := range system {
			block, ok := raw.(map[string]any)
			if !ok {
				// Unknown block shape: charge a flat fallback cost.
				total += claudeMaxUnknownContentTokens
				continue
			}
			total += estimateClaudeBlockTokens(block)
			if hasEphemeralCacheControl(block) {
				lastBreakpointAt = total
				projection.BreakpointCount++
				projection.HasBreakpoint = true
			}
		}
	}
	for _, rawMsg := range parsed.Messages {
		total += claudeMaxMessageOverheadTokens
		msg, ok := rawMsg.(map[string]any)
		if !ok {
			total += claudeMaxUnknownContentTokens
			continue
		}
		content, exists := msg["content"]
		if !exists {
			continue
		}
		msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
		total += msgTokens
		if msgBreakCount > 0 {
			// Translate the message-relative breakpoint offset into the
			// request-wide running total.
			lastBreakpointAt = total - msgTokens + msgLastBreak
			projection.BreakpointCount += msgBreakCount
			projection.HasBreakpoint = true
		}
	}
	if total <= 0 {
		total = 1
	}
	projection.TotalEstimatedTokens = total
	if projection.HasBreakpoint && lastBreakpointAt >= 0 {
		tail := total - lastBreakpointAt
		if tail <= 0 {
			tail = 1
		}
		projection.TailEstimatedTokens = tail
		return projection
	}
	// No block-level breakpoint found: a top-level ephemeral cache_control
	// counts as one breakpoint whose tail is the last user message.
	if hasTopLevelEphemeralCacheControl(parsed) {
		tail := estimateLastUserMessageTokens(parsed)
		if tail <= 0 {
			tail = 1
		}
		projection.HasBreakpoint = true
		projection.BreakpointCount = 1
		projection.TailEstimatedTokens = tail
	}
	return projection
}
// countExplicitCacheBreakpoints counts content blocks carrying an ephemeral
// cache_control, across both the system block list and all message contents.
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
	if parsed == nil {
		return 0
	}
	count := 0
	countBlocks := func(blocks []any) {
		for _, raw := range blocks {
			if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
				count++
			}
		}
	}
	if system, ok := parsed.System.([]any); ok {
		countBlocks(system)
	}
	for _, rawMsg := range parsed.Messages {
		msg, ok := rawMsg.(map[string]any)
		if !ok {
			continue
		}
		if blocks, ok := msg["content"].([]any); ok {
			countBlocks(blocks)
		}
	}
	return count
}
// hasEphemeralCacheControl reports whether the content block declares a
// cache_control of type "ephemeral" (case-insensitive, whitespace-trimmed).
// Both map[string]any and map[string]string payloads are accepted.
func hasEphemeralCacheControl(block map[string]any) bool {
	if block == nil {
		return false
	}
	raw, ok := block["cache_control"]
	if !ok || raw == nil {
		return false
	}
	var cacheType string
	switch cc := raw.(type) {
	case map[string]any:
		cacheType, _ = cc["type"].(string)
	case map[string]string:
		cacheType = cc["type"]
	default:
		return false
	}
	return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
}
// estimateClaudeContentTokens estimates tokens for a message "content" value.
// For block lists it also returns the running-total offset of the last cache
// breakpoint (lastBreakAt, -1 when absent) and the breakpoint count; scalar
// or unknown content never contributes breakpoints.
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
	switch value := content.(type) {
	case string:
		return estimateClaudeTextTokens(value), -1, 0
	case []any:
		tokens, lastBreakAt = 0, -1
		for _, raw := range value {
			block, ok := raw.(map[string]any)
			if !ok {
				// Unknown block shape: charge a flat fallback cost.
				tokens += claudeMaxUnknownContentTokens
				continue
			}
			tokens += estimateClaudeBlockTokens(block)
			if hasEphemeralCacheControl(block) {
				lastBreakAt = tokens
				breakpointCount++
			}
		}
		return tokens, lastBreakAt, breakpointCount
	default:
		return estimateStructuredTokens(value), -1, 0
	}
}
// estimateClaudeBlockTokens estimates tokens for a single content block based
// on its declared type (text, tool_result, tool_use, or fallback), always
// charging a per-block overhead and a flat fallback cost when the block
// yielded nothing measurable.
func estimateClaudeBlockTokens(block map[string]any) int {
	if block == nil {
		return claudeMaxUnknownContentTokens
	}
	tokens := claudeMaxBlockOverheadTokens
	blockType, _ := block["type"].(string)
	text, hasText := block["text"].(string)
	switch blockType {
	case "text":
		if hasText {
			tokens += estimateClaudeTextTokens(text)
		}
	case "tool_result":
		if content, ok := block["content"]; ok {
			nested, _, _ := estimateClaudeContentTokens(content)
			tokens += nested
		}
	case "tool_use":
		if name, ok := block["name"].(string); ok {
			tokens += estimateClaudeTextTokens(name)
		}
		if input, ok := block["input"]; ok {
			tokens += estimateStructuredTokens(input)
		}
	default:
		// Unknown block type: prefer a text field, else nested content.
		if hasText {
			tokens += estimateClaudeTextTokens(text)
		} else if content, ok := block["content"]; ok {
			nested, _, _ := estimateClaudeContentTokens(content)
			tokens += nested
		}
	}
	if tokens <= claudeMaxBlockOverheadTokens {
		// Nothing measurable inside: charge the unknown-content fallback.
		tokens += claudeMaxUnknownContentTokens
	}
	return tokens
}
// estimateLastUserMessageTokens estimates the tokens of the most recent
// user-role message (including per-message overhead), or 0 when none exists.
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
	if parsed == nil {
		return 0
	}
	messages := parsed.Messages
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if !ok {
			continue
		}
		role, _ := msg["role"].(string)
		if !strings.EqualFold(strings.TrimSpace(role), "user") {
			continue
		}
		tokens, _, _ := estimateClaudeContentTokens(msg["content"])
		return claudeMaxMessageOverheadTokens + tokens
	}
	return 0
}
// estimateStructuredTokens estimates tokens for arbitrary structured data by
// JSON-encoding it and measuring the resulting text; unmarshalable values
// cost the flat unknown-content fallback.
func estimateStructuredTokens(v any) int {
	if v == nil {
		return 0
	}
	encoded, err := json.Marshal(v)
	if err != nil {
		return claudeMaxUnknownContentTokens
	}
	return estimateClaudeTextTokens(string(encoded))
}
// estimateClaudeTextTokens counts tokens with the third-party tokenizer when
// available, falling back to the local heuristic otherwise.
func estimateClaudeTextTokens(text string) int {
	tokens, ok := estimateTokensByThirdPartyTokenizer(text)
	if !ok {
		return estimateClaudeTextTokensHeuristic(text)
	}
	return tokens
}
// estimateClaudeTextTokensHeuristic approximates a token count without a
// tokenizer: one token per non-ASCII rune, one per four ASCII characters
// (rounded up), floored at the whitespace-delimited word count. Returns 0
// for blank input, otherwise at least 1.
func estimateClaudeTextTokensHeuristic(text string) int {
	fields := strings.Fields(strings.TrimSpace(text))
	if len(fields) == 0 {
		return 0
	}
	normalized := strings.Join(fields, " ")
	ascii, wide := 0, 0
	for _, r := range normalized {
		if r > 127 {
			wide++
		} else {
			ascii++
		}
	}
	tokens := wide
	if ascii > 0 {
		// Roughly four ASCII characters per token, rounded up.
		tokens += (ascii + 3) / 4
	}
	if words := len(fields); words > tokens {
		tokens = words
	}
	if tokens <= 0 {
		return 1
	}
	return tokens
}

View File

@@ -1,6 +1,9 @@
package service
import "testing"
import (
"strings"
"testing"
)
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
usage := &ClaudeUsage{
@@ -13,8 +16,18 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
Model: "claude-sonnet-4-5",
Messages: []any{
map[string]any{
"role": "user",
"content": "请帮我总结这段代码并给出优化建议",
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": strings.Repeat("cached context ", 200),
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "summarize quickly",
},
},
},
},
}
@@ -34,6 +47,9 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
}
if usage.InputTokens > 100 {
t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
}
if usage.CacheCreation1hTokens <= 0 {
t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
}
@@ -42,22 +58,29 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
}
}
func TestComputeClaudeMaxSimulatedInputTokens_Deterministic(t *testing.T) {
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
Model: "claude-opus-4-5",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "请整理以下日志并定位错误根因"},
map[string]any{"type": "tool_use", "name": "grep_logs"},
map[string]any{
"type": "text",
"text": "build context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "what is failing now",
},
},
},
},
}
got1 := computeClaudeMaxSimulatedInputTokens(4096, parsed)
got2 := computeClaudeMaxSimulatedInputTokens(4096, parsed)
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
if got1 != got2 {
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
}
@@ -78,13 +101,54 @@ func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
CacheCreation1hTokens: 0,
},
},
ParsedRequest: &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "cached",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "tail",
},
},
},
},
},
APIKey: &APIKey{Group: group},
}
if !shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=true for claude group without cache creation")
t.Fatalf("expected simulate=true for claude group with cache signal")
}
input.ParsedRequest = &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "no cache signal"},
},
}
if shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=false when request has no cache signal")
}
input.ParsedRequest = &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "cached",
"cache_control": map[string]any{"type": "ephemeral"},
},
},
},
},
}
input.Result.Usage.CacheCreationInputTokens = 100
if shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=false when cache creation already exists")

View File

@@ -0,0 +1,41 @@
package service
import (
"sync"
tiktoken "github.com/pkoukk/tiktoken-go"
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
)
var (
	// claudeTokenizerOnce guards the one-time tokenizer initialization.
	claudeTokenizerOnce sync.Once
	// claudeTokenizer is the lazily-created shared encoder; nil when
	// initialization failed (see getClaudeTokenizer).
	claudeTokenizer *tiktoken.Tiktoken
)
// getClaudeTokenizer lazily initializes and returns a process-wide tiktoken
// encoder used to approximate Claude token counts. It may return nil when
// neither encoding can be loaded; callers must handle nil.
func getClaudeTokenizer() *tiktoken.Tiktoken {
	claudeTokenizerOnce.Do(func() {
		// Use offline loader to avoid runtime dictionary download.
		tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())
		// Use a high-capacity tokenizer as the default approximation for Claude payloads.
		enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
		if err != nil {
			// Fall back to the older cl100k encoding when o200k is unavailable.
			enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
		}
		// On failure claudeTokenizer stays nil and the heuristic path is used.
		if err == nil {
			claudeTokenizer = enc
		}
	})
	return claudeTokenizer
}
// estimateTokensByThirdPartyTokenizer counts tokens in text using the shared
// tiktoken encoder. The boolean is false when no encoder is available or the
// encoding produced no tokens, signalling callers to fall back to heuristics.
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
	enc := getClaudeTokenizer()
	if enc == nil {
		return 0, false
	}
	if count := len(enc.EncodeOrdinary(text)); count > 0 {
		return count, true
	}
	return 0, false
}

View File

@@ -50,8 +50,18 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsAndSkipsTTLOverride(t *tes
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": "please summarize the logs and provide root cause analysis",
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context for prior turns",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "please summarize the logs and provide root cause analysis",
},
},
},
},
},
@@ -138,3 +148,53 @@ func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T)
require.Equal(t, 0, log.CacheCreation1hTokens)
require.True(t, log.CacheTTLOverridden, "TTL override 生效时应打标")
}
func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationForce1H(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo)
groupID := int64(13)
input := &RecordUsageInput{
Result: &ForwardResult{
RequestID: "req-sim-3",
Model: "claude-sonnet-4",
Duration: time.Second,
Usage: ClaudeUsage{
InputTokens: 20,
CacheCreationInputTokens: 120,
CacheCreation5mTokens: 120,
},
},
APIKey: &APIKey{
ID: 3,
GroupID: &groupID,
Group: &Group{
ID: groupID,
Platform: PlatformAnthropic,
RateMultiplier: 1,
SimulateClaudeMaxEnabled: true,
},
},
User: &User{ID: 4},
Account: &Account{
ID: 5,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
},
}
err := svc.RecordUsage(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, repo.last)
log := repo.last
require.Equal(t, 20, log.InputTokens, "existing cache creation should not project input tokens")
require.Equal(t, 0, log.CacheCreation5mTokens, "existing cache creation should be forced to 1h")
require.Equal(t, 120, log.CacheCreation1hTokens)
require.Equal(t, 120, log.CacheCreationTokens)
require.True(t, log.CacheTTLOverridden, "force-to-1h should mark cache ttl overridden")
}

View File

@@ -57,12 +57,9 @@ const (
)
const (
claudeMaxSimInputMinTokens = 8
claudeMaxSimInputMaxTokens = 96
claudeMaxSimBaseOverheadTokens = 8
claudeMaxSimPerBlockOverhead = 2
claudeMaxSimSummaryMaxRunes = 160
claudeMaxSimContextDivisor = 16
claudeMaxMessageOverheadTokens = 3
claudeMaxBlockOverheadTokens = 1
claudeMaxUnknownContentTokens = 4
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
@@ -5575,224 +5572,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
return multiplier
}
func isClaudeFamilyModel(model string) bool {
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
if normalized == "" {
return false
}
return strings.Contains(normalized, "claude-")
}
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
return false
}
group := input.APIKey.Group
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
return false
}
model := input.Result.Model
if model == "" && input.ParsedRequest != nil {
model = input.ParsedRequest.Model
}
if !isClaudeFamilyModel(model) {
return false
}
usage := input.Result.Usage
if usage.InputTokens <= 0 {
return false
}
if usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0 {
return false
}
return true
}
func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool {
if result == nil {
return false
}
return projectUsageToClaudeMax1H(&result.Usage, parsed)
}
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
if usage == nil {
return false
}
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
if totalWindowTokens <= 1 {
return false
}
simulatedInputTokens := computeClaudeMaxSimulatedInputTokens(totalWindowTokens, parsed)
if simulatedInputTokens <= 0 {
simulatedInputTokens = 1
}
if simulatedInputTokens >= totalWindowTokens {
simulatedInputTokens = totalWindowTokens - 1
}
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
if usage.InputTokens == simulatedInputTokens &&
usage.CacheCreation5mTokens == 0 &&
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
usage.CacheCreationInputTokens == cacheCreation1hTokens {
return false
}
usage.InputTokens = simulatedInputTokens
usage.CacheCreation5mTokens = 0
usage.CacheCreation1hTokens = cacheCreation1hTokens
usage.CacheCreationInputTokens = cacheCreation1hTokens
return true
}
func computeClaudeMaxSimulatedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
if totalWindowTokens <= 1 {
return totalWindowTokens
}
summary, blockCount := extractTailUserMessageSummary(parsed)
if blockCount <= 0 {
blockCount = 1
}
asciiChars := 0
nonASCIIChars := 0
for _, r := range summary {
if r <= 127 {
asciiChars++
continue
}
nonASCIIChars++
}
lexicalTokens := nonASCIIChars
if asciiChars > 0 {
lexicalTokens += (asciiChars + 3) / 4
}
wordCount := len(strings.Fields(summary))
if wordCount > lexicalTokens {
lexicalTokens = wordCount
}
if lexicalTokens == 0 {
lexicalTokens = 1
}
structuralTokens := claudeMaxSimBaseOverheadTokens + blockCount*claudeMaxSimPerBlockOverhead
rawInputTokens := structuralTokens + lexicalTokens
maxInputTokens := clampInt(totalWindowTokens/claudeMaxSimContextDivisor, claudeMaxSimInputMinTokens, claudeMaxSimInputMaxTokens)
if totalWindowTokens <= claudeMaxSimInputMinTokens+1 {
maxInputTokens = totalWindowTokens - 1
}
if maxInputTokens <= 0 {
return totalWindowTokens
}
minInputTokens := 1
if totalWindowTokens > claudeMaxSimInputMinTokens+1 {
minInputTokens = claudeMaxSimInputMinTokens
}
return clampInt(rawInputTokens, minInputTokens, maxInputTokens)
}
func extractTailUserMessageSummary(parsed *ParsedRequest) (string, int) {
if parsed == nil || len(parsed.Messages) == 0 {
return "", 1
}
for i := len(parsed.Messages) - 1; i >= 0; i-- {
message, ok := parsed.Messages[i].(map[string]any)
if !ok {
continue
}
role, _ := message["role"].(string)
if !strings.EqualFold(strings.TrimSpace(role), "user") {
continue
}
summary, blockCount := summarizeUserContentBlocks(message["content"])
if blockCount <= 0 {
blockCount = 1
}
return summary, blockCount
}
return "", 1
}
func summarizeUserContentBlocks(content any) (string, int) {
appendSegment := func(segments []string, raw string) []string {
normalized := strings.Join(strings.Fields(strings.TrimSpace(raw)), " ")
if normalized == "" {
return segments
}
return append(segments, normalized)
}
switch value := content.(type) {
case string:
return trimClaudeMaxSummary(value), 1
case []any:
if len(value) == 0 {
return "", 1
}
segments := make([]string, 0, len(value))
for _, blockRaw := range value {
block, ok := blockRaw.(map[string]any)
if !ok {
continue
}
blockType, _ := block["type"].(string)
switch blockType {
case "text":
if text, ok := block["text"].(string); ok {
segments = appendSegment(segments, text)
}
case "tool_result":
nestedSummary, _ := summarizeUserContentBlocks(block["content"])
segments = appendSegment(segments, nestedSummary)
case "tool_use":
if name, ok := block["name"].(string); ok {
segments = appendSegment(segments, name)
}
default:
if text, ok := block["text"].(string); ok {
segments = appendSegment(segments, text)
}
}
}
return trimClaudeMaxSummary(strings.Join(segments, " ")), len(value)
default:
return "", 1
}
}
func trimClaudeMaxSummary(summary string) string {
normalized := strings.Join(strings.Fields(strings.TrimSpace(summary)), " ")
if normalized == "" {
return ""
}
runes := []rune(normalized)
if len(runes) > claudeMaxSimSummaryMaxRunes {
return string(runes[:claudeMaxSimSummaryMaxRunes])
}
return normalized
}
func clampInt(v, minValue, maxValue int) int {
if minValue > maxValue {
return minValue
}
if v < minValue {
return minValue
}
if v > maxValue {
return maxValue
}
return v
}
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
@@ -5829,25 +5608,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
result.Usage.InputTokens = 0
}
// Claude 分组模拟:将无写缓存 usage 映射为 claude-max 风格的 1h cache creation
simulatedClaudeMax := false
if shouldSimulateClaudeMaxUsage(input) {
beforeInputTokens := result.Usage.InputTokens
simulatedClaudeMax = applyClaudeMaxUsageSimulation(result, input.ParsedRequest)
if simulatedClaudeMax {
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
result.Model,
account.ID,
beforeInputTokens,
result.Usage.InputTokens,
result.Usage.CacheCreation1hTokens,
)
}
}
// Claude Max cache billing policy (group-level): force existing cache creation to 1h,
// otherwise simulate projection only when request carries cache signals.
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicy(input)
simulatedClaudeMax := claudeMaxOutcome.Simulated
forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
cacheTTLOverridden := forcedClaudeMax1H
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax && !forcedClaudeMax1H {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}