mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-26 17:34:47 +08:00
refactor: decouple claude max cache policy and add tokenizer
This commit is contained in:
@@ -57,12 +57,9 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
claudeMaxSimInputMinTokens = 8
|
||||
claudeMaxSimInputMaxTokens = 96
|
||||
claudeMaxSimBaseOverheadTokens = 8
|
||||
claudeMaxSimPerBlockOverhead = 2
|
||||
claudeMaxSimSummaryMaxRunes = 160
|
||||
claudeMaxSimContextDivisor = 16
|
||||
claudeMaxMessageOverheadTokens = 3
|
||||
claudeMaxBlockOverheadTokens = 1
|
||||
claudeMaxUnknownContentTokens = 4
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
@@ -5575,224 +5572,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
||||
return multiplier
|
||||
}
|
||||
|
||||
func isClaudeFamilyModel(model string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(normalized, "claude-")
|
||||
}
|
||||
|
||||
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
|
||||
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
|
||||
return false
|
||||
}
|
||||
group := input.APIKey.Group
|
||||
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
|
||||
return false
|
||||
}
|
||||
|
||||
model := input.Result.Model
|
||||
if model == "" && input.ParsedRequest != nil {
|
||||
model = input.ParsedRequest.Model
|
||||
}
|
||||
if !isClaudeFamilyModel(model) {
|
||||
return false
|
||||
}
|
||||
|
||||
usage := input.Result.Usage
|
||||
if usage.InputTokens <= 0 {
|
||||
return false
|
||||
}
|
||||
if usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
return projectUsageToClaudeMax1H(&result.Usage, parsed)
|
||||
}
|
||||
|
||||
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
|
||||
if usage == nil {
|
||||
return false
|
||||
}
|
||||
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if totalWindowTokens <= 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
simulatedInputTokens := computeClaudeMaxSimulatedInputTokens(totalWindowTokens, parsed)
|
||||
if simulatedInputTokens <= 0 {
|
||||
simulatedInputTokens = 1
|
||||
}
|
||||
if simulatedInputTokens >= totalWindowTokens {
|
||||
simulatedInputTokens = totalWindowTokens - 1
|
||||
}
|
||||
|
||||
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
|
||||
if usage.InputTokens == simulatedInputTokens &&
|
||||
usage.CacheCreation5mTokens == 0 &&
|
||||
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
|
||||
usage.CacheCreationInputTokens == cacheCreation1hTokens {
|
||||
return false
|
||||
}
|
||||
|
||||
usage.InputTokens = simulatedInputTokens
|
||||
usage.CacheCreation5mTokens = 0
|
||||
usage.CacheCreation1hTokens = cacheCreation1hTokens
|
||||
usage.CacheCreationInputTokens = cacheCreation1hTokens
|
||||
return true
|
||||
}
|
||||
|
||||
func computeClaudeMaxSimulatedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
|
||||
if totalWindowTokens <= 1 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
summary, blockCount := extractTailUserMessageSummary(parsed)
|
||||
if blockCount <= 0 {
|
||||
blockCount = 1
|
||||
}
|
||||
|
||||
asciiChars := 0
|
||||
nonASCIIChars := 0
|
||||
for _, r := range summary {
|
||||
if r <= 127 {
|
||||
asciiChars++
|
||||
continue
|
||||
}
|
||||
nonASCIIChars++
|
||||
}
|
||||
|
||||
lexicalTokens := nonASCIIChars
|
||||
if asciiChars > 0 {
|
||||
lexicalTokens += (asciiChars + 3) / 4
|
||||
}
|
||||
wordCount := len(strings.Fields(summary))
|
||||
if wordCount > lexicalTokens {
|
||||
lexicalTokens = wordCount
|
||||
}
|
||||
if lexicalTokens == 0 {
|
||||
lexicalTokens = 1
|
||||
}
|
||||
|
||||
structuralTokens := claudeMaxSimBaseOverheadTokens + blockCount*claudeMaxSimPerBlockOverhead
|
||||
rawInputTokens := structuralTokens + lexicalTokens
|
||||
|
||||
maxInputTokens := clampInt(totalWindowTokens/claudeMaxSimContextDivisor, claudeMaxSimInputMinTokens, claudeMaxSimInputMaxTokens)
|
||||
if totalWindowTokens <= claudeMaxSimInputMinTokens+1 {
|
||||
maxInputTokens = totalWindowTokens - 1
|
||||
}
|
||||
if maxInputTokens <= 0 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
minInputTokens := 1
|
||||
if totalWindowTokens > claudeMaxSimInputMinTokens+1 {
|
||||
minInputTokens = claudeMaxSimInputMinTokens
|
||||
}
|
||||
return clampInt(rawInputTokens, minInputTokens, maxInputTokens)
|
||||
}
|
||||
|
||||
func extractTailUserMessageSummary(parsed *ParsedRequest) (string, int) {
|
||||
if parsed == nil || len(parsed.Messages) == 0 {
|
||||
return "", 1
|
||||
}
|
||||
for i := len(parsed.Messages) - 1; i >= 0; i-- {
|
||||
message, ok := parsed.Messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := message["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "user") {
|
||||
continue
|
||||
}
|
||||
summary, blockCount := summarizeUserContentBlocks(message["content"])
|
||||
if blockCount <= 0 {
|
||||
blockCount = 1
|
||||
}
|
||||
return summary, blockCount
|
||||
}
|
||||
return "", 1
|
||||
}
|
||||
|
||||
func summarizeUserContentBlocks(content any) (string, int) {
|
||||
appendSegment := func(segments []string, raw string) []string {
|
||||
normalized := strings.Join(strings.Fields(strings.TrimSpace(raw)), " ")
|
||||
if normalized == "" {
|
||||
return segments
|
||||
}
|
||||
return append(segments, normalized)
|
||||
}
|
||||
|
||||
switch value := content.(type) {
|
||||
case string:
|
||||
return trimClaudeMaxSummary(value), 1
|
||||
case []any:
|
||||
if len(value) == 0 {
|
||||
return "", 1
|
||||
}
|
||||
segments := make([]string, 0, len(value))
|
||||
for _, blockRaw := range value {
|
||||
block, ok := blockRaw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := block["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
segments = appendSegment(segments, text)
|
||||
}
|
||||
case "tool_result":
|
||||
nestedSummary, _ := summarizeUserContentBlocks(block["content"])
|
||||
segments = appendSegment(segments, nestedSummary)
|
||||
case "tool_use":
|
||||
if name, ok := block["name"].(string); ok {
|
||||
segments = appendSegment(segments, name)
|
||||
}
|
||||
default:
|
||||
if text, ok := block["text"].(string); ok {
|
||||
segments = appendSegment(segments, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
return trimClaudeMaxSummary(strings.Join(segments, " ")), len(value)
|
||||
default:
|
||||
return "", 1
|
||||
}
|
||||
}
|
||||
|
||||
func trimClaudeMaxSummary(summary string) string {
|
||||
normalized := strings.Join(strings.Fields(strings.TrimSpace(summary)), " ")
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
runes := []rune(normalized)
|
||||
if len(runes) > claudeMaxSimSummaryMaxRunes {
|
||||
return string(runes[:claudeMaxSimSummaryMaxRunes])
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func clampInt(v, minValue, maxValue int) int {
|
||||
if minValue > maxValue {
|
||||
return minValue
|
||||
}
|
||||
if v < minValue {
|
||||
return minValue
|
||||
}
|
||||
if v > maxValue {
|
||||
return maxValue
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
@@ -5829,25 +5608,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Claude 分组模拟:将无写缓存 usage 映射为 claude-max 风格的 1h cache creation。
|
||||
simulatedClaudeMax := false
|
||||
if shouldSimulateClaudeMaxUsage(input) {
|
||||
beforeInputTokens := result.Usage.InputTokens
|
||||
simulatedClaudeMax = applyClaudeMaxUsageSimulation(result, input.ParsedRequest)
|
||||
if simulatedClaudeMax {
|
||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
|
||||
result.Model,
|
||||
account.ID,
|
||||
beforeInputTokens,
|
||||
result.Usage.InputTokens,
|
||||
result.Usage.CacheCreation1hTokens,
|
||||
)
|
||||
}
|
||||
}
|
||||
// Claude Max cache billing policy (group-level): force existing cache creation to 1h,
|
||||
// otherwise simulate projection only when request carries cache signals.
|
||||
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicy(input)
|
||||
simulatedClaudeMax := claudeMaxOutcome.Simulated
|
||||
forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H
|
||||
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
|
||||
cacheTTLOverridden := forcedClaudeMax1H
|
||||
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax && !forcedClaudeMax1H {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user