refactor: decouple claude max cache policy and add tokenizer

This commit is contained in:
erio
2026-02-27 12:18:22 +08:00
parent 886464b2e9
commit 6da2f54e50
7 changed files with 695 additions and 252 deletions

View File

@@ -57,12 +57,9 @@ const (
)
// Tuning constants for the Claude Max usage simulation, which re-labels part
// of a request's input tokens as 1h cache-creation tokens.
const (
// Lower clamp bound for the simulated input-token count; windows of at most
// this value plus one token are handled as a special small-window case.
claudeMaxSimInputMinTokens = 8
// Upper clamp bound for the per-request simulated input-token budget.
claudeMaxSimInputMaxTokens = 96
// Fixed structural overhead added to every simulated input estimate.
claudeMaxSimBaseOverheadTokens = 8
// Structural overhead added once per content block of the tail user message.
claudeMaxSimPerBlockOverhead = 2
// Maximum number of runes kept when trimming a user-message summary.
claudeMaxSimSummaryMaxRunes = 160
// The total window is divided by this value to derive the input-token budget
// before it is clamped to the min/max bounds above.
claudeMaxSimContextDivisor = 16
// NOTE(review): the three constants below are not referenced by any code
// visible in this chunk; presumably they are per-message, per-block and
// unknown-content token overheads for a token estimator — confirm against
// their users.
claudeMaxMessageOverheadTokens = 3
claudeMaxBlockOverheadTokens = 1
claudeMaxUnknownContentTokens = 4
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
@@ -5575,224 +5572,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
return multiplier
}
// isClaudeFamilyModel reports whether model, after being passed through
// claude.NormalizeModelID, trimmed and lowercased, contains the substring
// "claude-". An ID that normalizes to the empty string never matches.
func isClaudeFamilyModel(model string) bool {
	id := claude.NormalizeModelID(model)
	id = strings.TrimSpace(id)
	id = strings.ToLower(id)
	return id != "" && strings.Contains(id, "claude-")
}
// shouldSimulateClaudeMaxUsage decides whether the usage carried by input is
// eligible for the Claude Max simulation. All of the following must hold:
//   - input, its Result, its APIKey and the key's Group are non-nil;
//   - the group has SimulateClaudeMaxEnabled set and is on PlatformAnthropic;
//   - the model (Result.Model, falling back to ParsedRequest.Model when the
//     former is empty) is a Claude-family model;
//   - the usage reports positive input tokens and no cache-creation tokens of
//     any kind.
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
	if input == nil {
		return false
	}
	if input.Result == nil || input.APIKey == nil {
		return false
	}
	grp := input.APIKey.Group
	if grp == nil {
		return false
	}
	if !(grp.SimulateClaudeMaxEnabled && grp.Platform == PlatformAnthropic) {
		return false
	}

	// Prefer the model reported by the upstream result; fall back to the
	// model named in the parsed request.
	modelName := input.Result.Model
	if modelName == "" && input.ParsedRequest != nil {
		modelName = input.ParsedRequest.Model
	}
	if !isClaudeFamilyModel(modelName) {
		return false
	}

	u := input.Result.Usage
	hasCacheWrites := u.CacheCreationInputTokens > 0 ||
		u.CacheCreation5mTokens > 0 ||
		u.CacheCreation1hTokens > 0
	return u.InputTokens > 0 && !hasCacheWrites
}
// applyClaudeMaxUsageSimulation rewrites result.Usage in place via
// projectUsageToClaudeMax1H and reports whether anything changed.
// A nil result is left alone and reported as unchanged.
func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool {
	if result != nil {
		return projectUsageToClaudeMax1H(&result.Usage, parsed)
	}
	return false
}
// projectUsageToClaudeMax1H splits the token window of usage (input tokens
// plus 5m and 1h cache-creation tokens) into a small simulated input share and
// a 1h cache-creation share, writing the result back into usage.
//
// It returns true when usage was modified, and false when usage is nil, the
// window holds at most one token, or usage already matches the projection.
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
	if usage == nil {
		return false
	}

	window := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
	if window <= 1 {
		// A split needs at least one token on each side.
		return false
	}

	// Keep the input share inside [1, window-1] so both sides stay non-empty.
	inputShare := computeClaudeMaxSimulatedInputTokens(window, parsed)
	switch {
	case inputShare <= 0:
		inputShare = 1
	case inputShare >= window:
		inputShare = window - 1
	}
	cacheShare := window - inputShare

	alreadyProjected := usage.InputTokens == inputShare &&
		usage.CacheCreation5mTokens == 0 &&
		usage.CacheCreation1hTokens == cacheShare &&
		usage.CacheCreationInputTokens == cacheShare
	if alreadyProjected {
		return false
	}

	usage.InputTokens, usage.CacheCreation5mTokens = inputShare, 0
	usage.CacheCreation1hTokens = cacheShare
	usage.CacheCreationInputTokens = cacheShare
	return true
}
// computeClaudeMaxSimulatedInputTokens estimates how many of the
// totalWindowTokens should be reported as plain input tokens.
//
// The estimate is a fixed base overhead, plus a per-block overhead for the
// last user message, plus a lexical estimate of that message's summary: one
// token per non-ASCII rune plus one per four ASCII characters (rounded up),
// but never fewer than the whitespace-separated word count and never zero.
// The result is clamped to a budget derived from the window size. Windows of
// one token or fewer are returned unchanged.
func computeClaudeMaxSimulatedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
	if totalWindowTokens <= 1 {
		return totalWindowTokens
	}

	text, blocks := extractTailUserMessageSummary(parsed)
	if blocks < 1 {
		blocks = 1
	}

	var ascii, wide int
	for _, ch := range text {
		if ch > 127 {
			wide++
		} else {
			ascii++
		}
	}

	// (ascii+3)/4 is ceil(ascii/4) and evaluates to zero when ascii is zero.
	estimate := wide + (ascii+3)/4
	if words := len(strings.Fields(text)); words > estimate {
		estimate = words
	}
	if estimate == 0 {
		estimate = 1
	}
	raw := claudeMaxSimBaseOverheadTokens + blocks*claudeMaxSimPerBlockOverhead + estimate

	// Default bounds: [min, clamp(window/divisor, min, max)]. Windows too small
	// to honour the minimum use [1, window-1] instead.
	upper := clampInt(totalWindowTokens/claudeMaxSimContextDivisor, claudeMaxSimInputMinTokens, claudeMaxSimInputMaxTokens)
	lower := claudeMaxSimInputMinTokens
	if totalWindowTokens <= claudeMaxSimInputMinTokens+1 {
		upper = totalWindowTokens - 1
		lower = 1
	}
	if upper <= 0 {
		return totalWindowTokens
	}
	return clampInt(raw, lower, upper)
}
// extractTailUserMessageSummary walks parsed.Messages from the end and returns
// the summary text and block count of the last message whose role is "user"
// (compared case-insensitively after trimming). Entries that are not
// map[string]any are skipped. It returns ("", 1) when parsed is nil or no user
// message exists; the block count is never below 1.
func extractTailUserMessageSummary(parsed *ParsedRequest) (string, int) {
	if parsed == nil {
		return "", 1
	}
	for idx := len(parsed.Messages) - 1; idx >= 0; idx-- {
		msg, isMap := parsed.Messages[idx].(map[string]any)
		if !isMap {
			continue
		}
		role, _ := msg["role"].(string)
		if !strings.EqualFold(strings.TrimSpace(role), "user") {
			continue
		}
		text, blocks := summarizeUserContentBlocks(msg["content"])
		if blocks < 1 {
			blocks = 1
		}
		return text, blocks
	}
	return "", 1
}
// summarizeUserContentBlocks flattens a message "content" value into a short
// whitespace-normalized summary and reports how many blocks it contained.
//
//   - A string yields its trimmed summary and a count of 1.
//   - A non-empty []any yields the space-joined text of its map blocks and a
//     count of len(content): "tool_result" blocks contribute the summary of
//     their nested content, "tool_use" blocks their "name", and every other
//     block (including "text") its "text" field when that is a string.
//   - Anything else, including an empty slice, yields ("", 1).
func summarizeUserContentBlocks(content any) (string, int) {
	if text, isString := content.(string); isString {
		return trimClaudeMaxSummary(text), 1
	}
	blocks, isList := content.([]any)
	if !isList || len(blocks) == 0 {
		return "", 1
	}

	parts := make([]string, 0, len(blocks))
	for _, raw := range blocks {
		block, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}

		var piece string
		kind, _ := block["type"].(string)
		switch kind {
		case "tool_result":
			piece, _ = summarizeUserContentBlocks(block["content"])
		case "tool_use":
			piece, _ = block["name"].(string)
		default:
			// "text" blocks and unrecognized block types both expose "text".
			piece, _ = block["text"].(string)
		}

		// Collapse whitespace runs; Fields already drops leading and trailing
		// whitespace. Pieces that end up empty are not recorded.
		if piece = strings.Join(strings.Fields(piece), " "); piece != "" {
			parts = append(parts, piece)
		}
	}
	return trimClaudeMaxSummary(strings.Join(parts, " ")), len(blocks)
}
// trimClaudeMaxSummary collapses every run of whitespace in summary to a
// single space, drops leading and trailing whitespace, and truncates the
// result to at most claudeMaxSimSummaryMaxRunes runes.
func trimClaudeMaxSummary(summary string) string {
	// Fields splits on whitespace runs and ignores leading/trailing whitespace.
	collapsed := strings.Join(strings.Fields(summary), " ")
	// Truncate by rune, not by byte, so multi-byte characters are never split.
	if rs := []rune(collapsed); len(rs) > claudeMaxSimSummaryMaxRunes {
		return string(rs[:claudeMaxSimSummaryMaxRunes])
	}
	return collapsed
}
// clampInt limits v to the inclusive range [minValue, maxValue].
// When the bounds are inverted (minValue > maxValue) it returns minValue.
func clampInt(v, minValue, maxValue int) int {
	switch {
	case minValue > maxValue, v < minValue:
		return minValue
	case v > maxValue:
		return maxValue
	default:
		return v
	}
}
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
@@ -5829,25 +5608,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
result.Usage.InputTokens = 0
}
// Claude 分组模拟:将无写缓存 usage 映射为 claude-max 风格的 1h cache creation
simulatedClaudeMax := false
if shouldSimulateClaudeMaxUsage(input) {
beforeInputTokens := result.Usage.InputTokens
simulatedClaudeMax = applyClaudeMaxUsageSimulation(result, input.ParsedRequest)
if simulatedClaudeMax {
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
result.Model,
account.ID,
beforeInputTokens,
result.Usage.InputTokens,
result.Usage.CacheCreation1hTokens,
)
}
}
// Claude Max cache billing policy (group-level): force existing cache creation to 1h,
// otherwise simulate projection only when request carries cache signals.
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicy(input)
simulatedClaudeMax := claudeMaxOutcome.Simulated
forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
cacheTTLOverridden := forcedClaudeMax1H
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax && !forcedClaudeMax1H {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}