mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
refactor: decouple claude max cache policy and add tokenizer
This commit is contained in:
@@ -59,6 +59,7 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
@@ -109,6 +110,8 @@ require (
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.8
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
|
||||
@@ -64,6 +64,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
@@ -223,6 +225,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
||||
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
|
||||
500
backend/internal/service/claude_max_cache_billing_policy.go
Normal file
500
backend/internal/service/claude_max_cache_billing_policy.go
Normal file
@@ -0,0 +1,500 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// claudeMaxCacheBillingOutcome reports which rewrite, if any,
// applyClaudeMaxCacheBillingPolicy performed on a usage record.
type claudeMaxCacheBillingOutcome struct {
	// Simulated is true when input tokens were re-projected into 1h cache-creation tokens.
	Simulated bool
	// ForcedCache1H is true when already-reported cache-creation tokens were moved to the 1h bucket.
	ForcedCache1H bool
}
|
||||
|
||||
func applyClaudeMaxCacheBillingPolicy(input *RecordUsageInput) claudeMaxCacheBillingOutcome {
|
||||
var out claudeMaxCacheBillingOutcome
|
||||
if !shouldApplyClaudeMaxBillingRules(input) {
|
||||
return out
|
||||
}
|
||||
|
||||
if input == nil || input.Result == nil {
|
||||
return out
|
||||
}
|
||||
result := input.Result
|
||||
usage := &result.Usage
|
||||
accountID := int64(0)
|
||||
if input.Account != nil {
|
||||
accountID = input.Account.ID
|
||||
}
|
||||
|
||||
if hasCacheCreationTokens(*usage) {
|
||||
before5m := usage.CacheCreation5mTokens
|
||||
before1h := usage.CacheCreation1hTokens
|
||||
out.ForcedCache1H = safelyForceCacheCreationTo1H(usage)
|
||||
if out.ForcedCache1H {
|
||||
logger.LegacyPrintf("service.gateway", "force_claude_max_cache_1h: model=%s account=%d cache_creation_5m:%d->%d cache_creation_1h:%d->%d",
|
||||
result.Model,
|
||||
accountID,
|
||||
before5m,
|
||||
usage.CacheCreation5mTokens,
|
||||
before1h,
|
||||
usage.CacheCreation1hTokens,
|
||||
)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
if !shouldSimulateClaudeMaxUsage(input) {
|
||||
return out
|
||||
}
|
||||
beforeInputTokens := usage.InputTokens
|
||||
out.Simulated = safelyApplyClaudeMaxUsageSimulation(result, input.ParsedRequest)
|
||||
if out.Simulated {
|
||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
|
||||
result.Model,
|
||||
accountID,
|
||||
beforeInputTokens,
|
||||
usage.InputTokens,
|
||||
usage.CacheCreation1hTokens,
|
||||
)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isClaudeFamilyModel(model string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(normalized, "claude-")
|
||||
}
|
||||
|
||||
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
|
||||
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
|
||||
return false
|
||||
}
|
||||
group := input.APIKey.Group
|
||||
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
|
||||
return false
|
||||
}
|
||||
|
||||
model := input.Result.Model
|
||||
if model == "" && input.ParsedRequest != nil {
|
||||
model = input.ParsedRequest.Model
|
||||
}
|
||||
if !isClaudeFamilyModel(model) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func hasCacheCreationTokens(usage ClaudeUsage) bool {
|
||||
return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0
|
||||
}
|
||||
|
||||
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
|
||||
if !shouldApplyClaudeMaxBillingRules(input) {
|
||||
return false
|
||||
}
|
||||
if !hasClaudeCacheSignals(input.ParsedRequest) {
|
||||
return false
|
||||
}
|
||||
usage := input.Result.Usage
|
||||
if usage.InputTokens <= 0 {
|
||||
return false
|
||||
}
|
||||
if hasCacheCreationTokens(usage) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func forceCacheCreationTo1H(usage *ClaudeUsage) bool {
|
||||
if usage == nil || !hasCacheCreationTokens(*usage) {
|
||||
return false
|
||||
}
|
||||
|
||||
before5m := usage.CacheCreation5mTokens
|
||||
before1h := usage.CacheCreation1hTokens
|
||||
beforeAgg := usage.CacheCreationInputTokens
|
||||
|
||||
_ = applyCacheTTLOverride(usage, "1h")
|
||||
total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if total <= 0 {
|
||||
total = usage.CacheCreationInputTokens
|
||||
}
|
||||
if total <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
usage.CacheCreation5mTokens = 0
|
||||
usage.CacheCreation1hTokens = total
|
||||
usage.CacheCreationInputTokens = total
|
||||
|
||||
return before5m != usage.CacheCreation5mTokens ||
|
||||
before1h != usage.CacheCreation1hTokens ||
|
||||
beforeAgg != usage.CacheCreationInputTokens
|
||||
}
|
||||
|
||||
func safelyApplyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) (changed bool) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
|
||||
changed = false
|
||||
}
|
||||
}()
|
||||
return applyClaudeMaxUsageSimulation(result, parsed)
|
||||
}
|
||||
|
||||
func safelyForceCacheCreationTo1H(usage *ClaudeUsage) (changed bool) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.LegacyPrintf("service.gateway", "force_cache_creation_1h skipped: panic=%v", r)
|
||||
changed = false
|
||||
}
|
||||
}()
|
||||
return forceCacheCreationTo1H(usage)
|
||||
}
|
||||
|
||||
func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
return projectUsageToClaudeMax1H(&result.Usage, parsed)
|
||||
}
|
||||
|
||||
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
|
||||
if usage == nil {
|
||||
return false
|
||||
}
|
||||
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if totalWindowTokens <= 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed)
|
||||
if simulatedInputTokens <= 0 {
|
||||
simulatedInputTokens = 1
|
||||
}
|
||||
if simulatedInputTokens >= totalWindowTokens {
|
||||
simulatedInputTokens = totalWindowTokens - 1
|
||||
}
|
||||
|
||||
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
|
||||
if usage.InputTokens == simulatedInputTokens &&
|
||||
usage.CacheCreation5mTokens == 0 &&
|
||||
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
|
||||
usage.CacheCreationInputTokens == cacheCreation1hTokens {
|
||||
return false
|
||||
}
|
||||
|
||||
usage.InputTokens = simulatedInputTokens
|
||||
usage.CacheCreation5mTokens = 0
|
||||
usage.CacheCreation1hTokens = cacheCreation1hTokens
|
||||
usage.CacheCreationInputTokens = cacheCreation1hTokens
|
||||
return true
|
||||
}
|
||||
|
||||
// claudeCacheProjection summarizes where cache breakpoints fall within a
// request's estimated token stream.
type claudeCacheProjection struct {
	// HasBreakpoint is true when at least one cache breakpoint was found.
	HasBreakpoint bool
	// BreakpointCount is the number of breakpoints counted.
	BreakpointCount int
	// TotalEstimatedTokens is the estimated size of the whole prompt.
	TotalEstimatedTokens int
	// TailEstimatedTokens is the estimated size of the part after the last breakpoint.
	TailEstimatedTokens int
}
|
||||
|
||||
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
|
||||
if totalWindowTokens <= 1 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
projection := analyzeClaudeCacheProjection(parsed)
|
||||
if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
totalEstimate := int64(projection.TotalEstimatedTokens)
|
||||
tailEstimate := int64(projection.TailEstimatedTokens)
|
||||
if tailEstimate > totalEstimate {
|
||||
tailEstimate = totalEstimate
|
||||
}
|
||||
|
||||
scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate
|
||||
if scaled <= 0 {
|
||||
scaled = 1
|
||||
}
|
||||
if scaled >= int64(totalWindowTokens) {
|
||||
scaled = int64(totalWindowTokens - 1)
|
||||
}
|
||||
return int(scaled)
|
||||
}
|
||||
|
||||
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
if hasTopLevelEphemeralCacheControl(parsed) {
|
||||
return true
|
||||
}
|
||||
return countExplicitCacheBreakpoints(parsed) > 0
|
||||
}
|
||||
|
||||
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
|
||||
if parsed == nil || len(parsed.Body) == 0 {
|
||||
return false
|
||||
}
|
||||
cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String())
|
||||
return strings.EqualFold(cacheType, "ephemeral")
|
||||
}
|
||||
|
||||
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
|
||||
var projection claudeCacheProjection
|
||||
if parsed == nil {
|
||||
return projection
|
||||
}
|
||||
|
||||
total := 0
|
||||
lastBreakpointAt := -1
|
||||
|
||||
switch system := parsed.System.(type) {
|
||||
case string:
|
||||
total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
|
||||
case []any:
|
||||
for _, raw := range system {
|
||||
block, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
total += estimateClaudeBlockTokens(block)
|
||||
if hasEphemeralCacheControl(block) {
|
||||
lastBreakpointAt = total
|
||||
projection.BreakpointCount++
|
||||
projection.HasBreakpoint = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rawMsg := range parsed.Messages {
|
||||
total += claudeMaxMessageOverheadTokens
|
||||
msg, ok := rawMsg.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
content, exists := msg["content"]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
|
||||
total += msgTokens
|
||||
if msgBreakCount > 0 {
|
||||
lastBreakpointAt = total - msgTokens + msgLastBreak
|
||||
projection.BreakpointCount += msgBreakCount
|
||||
projection.HasBreakpoint = true
|
||||
}
|
||||
}
|
||||
|
||||
if total <= 0 {
|
||||
total = 1
|
||||
}
|
||||
projection.TotalEstimatedTokens = total
|
||||
|
||||
if projection.HasBreakpoint && lastBreakpointAt >= 0 {
|
||||
tail := total - lastBreakpointAt
|
||||
if tail <= 0 {
|
||||
tail = 1
|
||||
}
|
||||
projection.TailEstimatedTokens = tail
|
||||
return projection
|
||||
}
|
||||
|
||||
if hasTopLevelEphemeralCacheControl(parsed) {
|
||||
tail := estimateLastUserMessageTokens(parsed)
|
||||
if tail <= 0 {
|
||||
tail = 1
|
||||
}
|
||||
projection.HasBreakpoint = true
|
||||
projection.BreakpointCount = 1
|
||||
projection.TailEstimatedTokens = tail
|
||||
}
|
||||
return projection
|
||||
}
|
||||
|
||||
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
|
||||
if parsed == nil {
|
||||
return 0
|
||||
}
|
||||
total := 0
|
||||
if system, ok := parsed.System.([]any); ok {
|
||||
for _, raw := range system {
|
||||
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, rawMsg := range parsed.Messages {
|
||||
msg, ok := rawMsg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msg["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, raw := range content {
|
||||
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// hasEphemeralCacheControl reports whether block has a "cache_control" object
// whose "type" equals "ephemeral" (case-insensitive, surrounding whitespace
// ignored). Both map[string]any and map[string]string shapes are accepted;
// any other shape, a missing key, or a nil block yields false.
func hasEphemeralCacheControl(block map[string]any) bool {
	// Indexing a nil map yields nil, which falls through to the default case.
	var cacheType string
	switch control := block["cache_control"].(type) {
	case map[string]any:
		cacheType, _ = control["type"].(string)
	case map[string]string:
		cacheType = control["type"]
	default:
		return false
	}
	return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
}
|
||||
|
||||
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
|
||||
switch value := content.(type) {
|
||||
case string:
|
||||
return estimateClaudeTextTokens(value), -1, 0
|
||||
case []any:
|
||||
total := 0
|
||||
lastBreak := -1
|
||||
breaks := 0
|
||||
for _, raw := range value {
|
||||
block, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
total += claudeMaxUnknownContentTokens
|
||||
continue
|
||||
}
|
||||
total += estimateClaudeBlockTokens(block)
|
||||
if hasEphemeralCacheControl(block) {
|
||||
lastBreak = total
|
||||
breaks++
|
||||
}
|
||||
}
|
||||
return total, lastBreak, breaks
|
||||
default:
|
||||
return estimateStructuredTokens(value), -1, 0
|
||||
}
|
||||
}
|
||||
|
||||
func estimateClaudeBlockTokens(block map[string]any) int {
|
||||
if block == nil {
|
||||
return claudeMaxUnknownContentTokens
|
||||
}
|
||||
tokens := claudeMaxBlockOverheadTokens
|
||||
blockType, _ := block["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(text)
|
||||
}
|
||||
case "tool_result":
|
||||
if content, ok := block["content"]; ok {
|
||||
nested, _, _ := estimateClaudeContentTokens(content)
|
||||
tokens += nested
|
||||
}
|
||||
case "tool_use":
|
||||
if name, ok := block["name"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(name)
|
||||
}
|
||||
if input, ok := block["input"]; ok {
|
||||
tokens += estimateStructuredTokens(input)
|
||||
}
|
||||
default:
|
||||
if text, ok := block["text"].(string); ok {
|
||||
tokens += estimateClaudeTextTokens(text)
|
||||
} else if content, ok := block["content"]; ok {
|
||||
nested, _, _ := estimateClaudeContentTokens(content)
|
||||
tokens += nested
|
||||
}
|
||||
}
|
||||
if tokens <= claudeMaxBlockOverheadTokens {
|
||||
tokens += claudeMaxUnknownContentTokens
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
|
||||
if parsed == nil || len(parsed.Messages) == 0 {
|
||||
return 0
|
||||
}
|
||||
for i := len(parsed.Messages) - 1; i >= 0; i-- {
|
||||
msg, ok := parsed.Messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := msg["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "user") {
|
||||
continue
|
||||
}
|
||||
tokens, _, _ := estimateClaudeContentTokens(msg["content"])
|
||||
return claudeMaxMessageOverheadTokens + tokens
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func estimateStructuredTokens(v any) int {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
raw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return claudeMaxUnknownContentTokens
|
||||
}
|
||||
return estimateClaudeTextTokens(string(raw))
|
||||
}
|
||||
|
||||
func estimateClaudeTextTokens(text string) int {
|
||||
if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok {
|
||||
return tokens
|
||||
}
|
||||
return estimateClaudeTextTokensHeuristic(text)
|
||||
}
|
||||
|
||||
// estimateClaudeTextTokensHeuristic estimates tokens without a tokenizer.
//
// Whitespace runs are collapsed to single spaces first. Every non-ASCII rune
// counts as one token and every four ASCII characters (rounded up) count as
// one token; the result is never lower than the number of whitespace-separated
// words. Blank text costs 0.
func estimateClaudeTextTokensHeuristic(text string) int {
	// Fields already ignores leading and trailing whitespace.
	words := strings.Fields(text)
	if len(words) == 0 {
		return 0
	}
	normalized := strings.Join(words, " ")

	var ascii, wide int
	for _, r := range normalized {
		if r > 127 {
			wide++
		} else {
			ascii++
		}
	}

	estimate := wide
	if ascii > 0 {
		estimate += (ascii + 3) / 4
	}
	if len(words) > estimate {
		estimate = len(words)
	}
	if estimate <= 0 {
		return 1
	}
	return estimate
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
|
||||
usage := &ClaudeUsage{
|
||||
@@ -13,8 +16,18 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
|
||||
Model: "claude-sonnet-4-5",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": "请帮我总结这段代码并给出优化建议",
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": strings.Repeat("cached context ", 200),
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "summarize quickly",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -34,6 +47,9 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
|
||||
if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
|
||||
t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
|
||||
}
|
||||
if usage.InputTokens > 100 {
|
||||
t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
|
||||
}
|
||||
if usage.CacheCreation1hTokens <= 0 {
|
||||
t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
|
||||
}
|
||||
@@ -42,22 +58,29 @@ func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeMaxSimulatedInputTokens_Deterministic(t *testing.T) {
|
||||
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Model: "claude-opus-4-5",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "请整理以下日志并定位错误根因"},
|
||||
map[string]any{"type": "tool_use", "name": "grep_logs"},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "build context",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "what is failing now",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got1 := computeClaudeMaxSimulatedInputTokens(4096, parsed)
|
||||
got2 := computeClaudeMaxSimulatedInputTokens(4096, parsed)
|
||||
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
||||
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
||||
if got1 != got2 {
|
||||
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
|
||||
}
|
||||
@@ -78,13 +101,54 @@ func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
|
||||
CacheCreation1hTokens: 0,
|
||||
},
|
||||
},
|
||||
ParsedRequest: &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "cached",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "tail",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
APIKey: &APIKey{Group: group},
|
||||
}
|
||||
|
||||
if !shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=true for claude group without cache creation")
|
||||
t.Fatalf("expected simulate=true for claude group with cache signal")
|
||||
}
|
||||
|
||||
input.ParsedRequest = &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "no cache signal"},
|
||||
},
|
||||
}
|
||||
if shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=false when request has no cache signal")
|
||||
}
|
||||
|
||||
input.ParsedRequest = &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "cached",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
input.Result.Usage.CacheCreationInputTokens = 100
|
||||
if shouldSimulateClaudeMaxUsage(input) {
|
||||
t.Fatalf("expected simulate=false when cache creation already exists")
|
||||
|
||||
41
backend/internal/service/claude_tokenizer.go
Normal file
41
backend/internal/service/claude_tokenizer.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
tiktoken "github.com/pkoukk/tiktoken-go"
|
||||
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
|
||||
)
|
||||
|
||||
var (
|
||||
claudeTokenizerOnce sync.Once
|
||||
claudeTokenizer *tiktoken.Tiktoken
|
||||
)
|
||||
|
||||
func getClaudeTokenizer() *tiktoken.Tiktoken {
|
||||
claudeTokenizerOnce.Do(func() {
|
||||
// Use offline loader to avoid runtime dictionary download.
|
||||
tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())
|
||||
// Use a high-capacity tokenizer as the default approximation for Claude payloads.
|
||||
enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
|
||||
if err != nil {
|
||||
enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
|
||||
}
|
||||
if err == nil {
|
||||
claudeTokenizer = enc
|
||||
}
|
||||
})
|
||||
return claudeTokenizer
|
||||
}
|
||||
|
||||
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
|
||||
enc := getClaudeTokenizer()
|
||||
if enc == nil {
|
||||
return 0, false
|
||||
}
|
||||
tokens := len(enc.EncodeOrdinary(text))
|
||||
if tokens <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return tokens, true
|
||||
}
|
||||
@@ -50,8 +50,18 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsAndSkipsTTLOverride(t *tes
|
||||
Model: "claude-sonnet-4",
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": "please summarize the logs and provide root cause analysis",
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "long cached context for prior turns",
|
||||
"cache_control": map[string]any{"type": "ephemeral"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "please summarize the logs and provide root cause analysis",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -138,3 +148,53 @@ func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T)
|
||||
require.Equal(t, 0, log.CacheCreation1hTokens)
|
||||
require.True(t, log.CacheTTLOverridden, "TTL override 生效时应打标")
|
||||
}
|
||||
|
||||
func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationForce1H(t *testing.T) {
|
||||
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
||||
svc := newGatewayServiceForRecordUsageTest(repo)
|
||||
|
||||
groupID := int64(13)
|
||||
input := &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "req-sim-3",
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 20,
|
||||
CacheCreationInputTokens: 120,
|
||||
CacheCreation5mTokens: 120,
|
||||
},
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 3,
|
||||
GroupID: &groupID,
|
||||
Group: &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1,
|
||||
SimulateClaudeMaxEnabled: true,
|
||||
},
|
||||
},
|
||||
User: &User{ID: 4},
|
||||
Account: &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"cache_ttl_override_enabled": true,
|
||||
"cache_ttl_override_target": "5m",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := svc.RecordUsage(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, repo.last)
|
||||
|
||||
log := repo.last
|
||||
require.Equal(t, 20, log.InputTokens, "existing cache creation should not project input tokens")
|
||||
require.Equal(t, 0, log.CacheCreation5mTokens, "existing cache creation should be forced to 1h")
|
||||
require.Equal(t, 120, log.CacheCreation1hTokens)
|
||||
require.Equal(t, 120, log.CacheCreationTokens)
|
||||
require.True(t, log.CacheTTLOverridden, "force-to-1h should mark cache ttl overridden")
|
||||
}
|
||||
|
||||
@@ -57,12 +57,9 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
claudeMaxSimInputMinTokens = 8
|
||||
claudeMaxSimInputMaxTokens = 96
|
||||
claudeMaxSimBaseOverheadTokens = 8
|
||||
claudeMaxSimPerBlockOverhead = 2
|
||||
claudeMaxSimSummaryMaxRunes = 160
|
||||
claudeMaxSimContextDivisor = 16
|
||||
claudeMaxMessageOverheadTokens = 3
|
||||
claudeMaxBlockOverheadTokens = 1
|
||||
claudeMaxUnknownContentTokens = 4
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
@@ -5575,224 +5572,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
||||
return multiplier
|
||||
}
|
||||
|
||||
func isClaudeFamilyModel(model string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(normalized, "claude-")
|
||||
}
|
||||
|
||||
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
|
||||
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
|
||||
return false
|
||||
}
|
||||
group := input.APIKey.Group
|
||||
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
|
||||
return false
|
||||
}
|
||||
|
||||
model := input.Result.Model
|
||||
if model == "" && input.ParsedRequest != nil {
|
||||
model = input.ParsedRequest.Model
|
||||
}
|
||||
if !isClaudeFamilyModel(model) {
|
||||
return false
|
||||
}
|
||||
|
||||
usage := input.Result.Usage
|
||||
if usage.InputTokens <= 0 {
|
||||
return false
|
||||
}
|
||||
if usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
return projectUsageToClaudeMax1H(&result.Usage, parsed)
|
||||
}
|
||||
|
||||
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
|
||||
if usage == nil {
|
||||
return false
|
||||
}
|
||||
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
||||
if totalWindowTokens <= 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
simulatedInputTokens := computeClaudeMaxSimulatedInputTokens(totalWindowTokens, parsed)
|
||||
if simulatedInputTokens <= 0 {
|
||||
simulatedInputTokens = 1
|
||||
}
|
||||
if simulatedInputTokens >= totalWindowTokens {
|
||||
simulatedInputTokens = totalWindowTokens - 1
|
||||
}
|
||||
|
||||
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
|
||||
if usage.InputTokens == simulatedInputTokens &&
|
||||
usage.CacheCreation5mTokens == 0 &&
|
||||
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
|
||||
usage.CacheCreationInputTokens == cacheCreation1hTokens {
|
||||
return false
|
||||
}
|
||||
|
||||
usage.InputTokens = simulatedInputTokens
|
||||
usage.CacheCreation5mTokens = 0
|
||||
usage.CacheCreation1hTokens = cacheCreation1hTokens
|
||||
usage.CacheCreationInputTokens = cacheCreation1hTokens
|
||||
return true
|
||||
}
|
||||
|
||||
func computeClaudeMaxSimulatedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
|
||||
if totalWindowTokens <= 1 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
summary, blockCount := extractTailUserMessageSummary(parsed)
|
||||
if blockCount <= 0 {
|
||||
blockCount = 1
|
||||
}
|
||||
|
||||
asciiChars := 0
|
||||
nonASCIIChars := 0
|
||||
for _, r := range summary {
|
||||
if r <= 127 {
|
||||
asciiChars++
|
||||
continue
|
||||
}
|
||||
nonASCIIChars++
|
||||
}
|
||||
|
||||
lexicalTokens := nonASCIIChars
|
||||
if asciiChars > 0 {
|
||||
lexicalTokens += (asciiChars + 3) / 4
|
||||
}
|
||||
wordCount := len(strings.Fields(summary))
|
||||
if wordCount > lexicalTokens {
|
||||
lexicalTokens = wordCount
|
||||
}
|
||||
if lexicalTokens == 0 {
|
||||
lexicalTokens = 1
|
||||
}
|
||||
|
||||
structuralTokens := claudeMaxSimBaseOverheadTokens + blockCount*claudeMaxSimPerBlockOverhead
|
||||
rawInputTokens := structuralTokens + lexicalTokens
|
||||
|
||||
maxInputTokens := clampInt(totalWindowTokens/claudeMaxSimContextDivisor, claudeMaxSimInputMinTokens, claudeMaxSimInputMaxTokens)
|
||||
if totalWindowTokens <= claudeMaxSimInputMinTokens+1 {
|
||||
maxInputTokens = totalWindowTokens - 1
|
||||
}
|
||||
if maxInputTokens <= 0 {
|
||||
return totalWindowTokens
|
||||
}
|
||||
|
||||
minInputTokens := 1
|
||||
if totalWindowTokens > claudeMaxSimInputMinTokens+1 {
|
||||
minInputTokens = claudeMaxSimInputMinTokens
|
||||
}
|
||||
return clampInt(rawInputTokens, minInputTokens, maxInputTokens)
|
||||
}
|
||||
|
||||
func extractTailUserMessageSummary(parsed *ParsedRequest) (string, int) {
|
||||
if parsed == nil || len(parsed.Messages) == 0 {
|
||||
return "", 1
|
||||
}
|
||||
for i := len(parsed.Messages) - 1; i >= 0; i-- {
|
||||
message, ok := parsed.Messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := message["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "user") {
|
||||
continue
|
||||
}
|
||||
summary, blockCount := summarizeUserContentBlocks(message["content"])
|
||||
if blockCount <= 0 {
|
||||
blockCount = 1
|
||||
}
|
||||
return summary, blockCount
|
||||
}
|
||||
return "", 1
|
||||
}
|
||||
|
||||
func summarizeUserContentBlocks(content any) (string, int) {
|
||||
appendSegment := func(segments []string, raw string) []string {
|
||||
normalized := strings.Join(strings.Fields(strings.TrimSpace(raw)), " ")
|
||||
if normalized == "" {
|
||||
return segments
|
||||
}
|
||||
return append(segments, normalized)
|
||||
}
|
||||
|
||||
switch value := content.(type) {
|
||||
case string:
|
||||
return trimClaudeMaxSummary(value), 1
|
||||
case []any:
|
||||
if len(value) == 0 {
|
||||
return "", 1
|
||||
}
|
||||
segments := make([]string, 0, len(value))
|
||||
for _, blockRaw := range value {
|
||||
block, ok := blockRaw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := block["type"].(string)
|
||||
switch blockType {
|
||||
case "text":
|
||||
if text, ok := block["text"].(string); ok {
|
||||
segments = appendSegment(segments, text)
|
||||
}
|
||||
case "tool_result":
|
||||
nestedSummary, _ := summarizeUserContentBlocks(block["content"])
|
||||
segments = appendSegment(segments, nestedSummary)
|
||||
case "tool_use":
|
||||
if name, ok := block["name"].(string); ok {
|
||||
segments = appendSegment(segments, name)
|
||||
}
|
||||
default:
|
||||
if text, ok := block["text"].(string); ok {
|
||||
segments = appendSegment(segments, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
return trimClaudeMaxSummary(strings.Join(segments, " ")), len(value)
|
||||
default:
|
||||
return "", 1
|
||||
}
|
||||
}
|
||||
|
||||
func trimClaudeMaxSummary(summary string) string {
|
||||
normalized := strings.Join(strings.Fields(strings.TrimSpace(summary)), " ")
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
runes := []rune(normalized)
|
||||
if len(runes) > claudeMaxSimSummaryMaxRunes {
|
||||
return string(runes[:claudeMaxSimSummaryMaxRunes])
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// clampInt limits v to the inclusive range [minValue, maxValue].
// When the range is inverted (minValue > maxValue), minValue is returned
// regardless of v.
func clampInt(v, minValue, maxValue int) int {
	switch {
	case minValue > maxValue, v < minValue:
		return minValue
	case v > maxValue:
		return maxValue
	default:
		return v
	}
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
@@ -5829,25 +5608,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
result.Usage.InputTokens = 0
|
||||
}
|
||||
|
||||
// Claude 分组模拟:将无写缓存 usage 映射为 claude-max 风格的 1h cache creation。
|
||||
simulatedClaudeMax := false
|
||||
if shouldSimulateClaudeMaxUsage(input) {
|
||||
beforeInputTokens := result.Usage.InputTokens
|
||||
simulatedClaudeMax = applyClaudeMaxUsageSimulation(result, input.ParsedRequest)
|
||||
if simulatedClaudeMax {
|
||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
|
||||
result.Model,
|
||||
account.ID,
|
||||
beforeInputTokens,
|
||||
result.Usage.InputTokens,
|
||||
result.Usage.CacheCreation1hTokens,
|
||||
)
|
||||
}
|
||||
}
|
||||
// Claude Max cache billing policy (group-level): force existing cache creation to 1h,
|
||||
// otherwise simulate projection only when request carries cache signals.
|
||||
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicy(input)
|
||||
simulatedClaudeMax := claudeMaxOutcome.Simulated
|
||||
forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H
|
||||
|
||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||
cacheTTLOverridden := false
|
||||
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
|
||||
cacheTTLOverridden := forcedClaudeMax1H
|
||||
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax && !forcedClaudeMax1H {
|
||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user