refactor: isolate claude max response usage simulation by group toggle

This commit is contained in:
erio
2026-02-27 16:14:07 +08:00
parent e71be7e0f1
commit 61ef73cb12
5 changed files with 356 additions and 48 deletions

View File

@@ -31,19 +31,7 @@ func applyClaudeMaxCacheBillingPolicy(input *RecordUsageInput) claudeMaxCacheBil
} }
if hasCacheCreationTokens(*usage) { if hasCacheCreationTokens(*usage) {
before5m := usage.CacheCreation5mTokens // Upstream already returned cache creation usage; keep original usage.
before1h := usage.CacheCreation1hTokens
out.ForcedCache1H = safelyForceCacheCreationTo1H(usage)
if out.ForcedCache1H {
logger.LegacyPrintf("service.gateway", "force_claude_max_cache_1h: model=%s account=%d cache_creation_5m:%d->%d cache_creation_1h:%d->%d",
result.Model,
accountID,
before5m,
usage.CacheCreation5mTokens,
before1h,
usage.CacheCreation1hTokens,
)
}
return out return out
} }
@@ -72,7 +60,7 @@ func detectClaudeMaxCacheBillingOutcomeForUsage(usage ClaudeUsage, parsed *Parse
return out return out
} }
if hasCacheCreationTokens(usage) { if hasCacheCreationTokens(usage) {
out.ForcedCache1H = true // Upstream already returned cache creation usage; keep original usage.
return out return out
} }
if shouldSimulateClaudeMaxUsageForUsage(usage, parsed) { if shouldSimulateClaudeMaxUsageForUsage(usage, parsed) {
@@ -93,21 +81,7 @@ func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedR
} }
if hasCacheCreationTokens(*usage) { if hasCacheCreationTokens(*usage) {
before5m := usage.CacheCreation5mTokens // Upstream already returned cache creation usage; keep original usage.
before1h := usage.CacheCreation1hTokens
changed := safelyForceCacheCreationTo1H(usage)
// Even when value is already 1h, still mark forced to skip account TTL override.
out.ForcedCache1H = true
if changed {
logger.LegacyPrintf("service.gateway", "force_claude_max_cache_1h: model=%s account=%d cache_creation_5m:%d->%d cache_creation_1h:%d->%d",
resolvedModel,
accountID,
before5m,
usage.CacheCreation5mTokens,
before1h,
usage.CacheCreation1hTokens,
)
}
return out return out
} }

View File

@@ -0,0 +1,147 @@
package service
import (
"context"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/tidwall/sjson"
)
type claudeMaxResponseRewriteContext struct {
Parsed *ParsedRequest
Group *Group
}
type claudeMaxResponseRewriteContextKeyType struct{}
var claudeMaxResponseRewriteContextKey = claudeMaxResponseRewriteContextKeyType{}
func withClaudeMaxResponseRewriteContext(ctx context.Context, c *gin.Context, parsed *ParsedRequest) context.Context {
if ctx == nil {
ctx = context.Background()
}
value := claudeMaxResponseRewriteContext{
Parsed: parsed,
Group: claudeMaxGroupFromGinContext(c),
}
return context.WithValue(ctx, claudeMaxResponseRewriteContextKey, value)
}
func claudeMaxResponseRewriteContextFromContext(ctx context.Context) claudeMaxResponseRewriteContext {
if ctx == nil {
return claudeMaxResponseRewriteContext{}
}
value, _ := ctx.Value(claudeMaxResponseRewriteContextKey).(claudeMaxResponseRewriteContext)
return value
}
func claudeMaxGroupFromGinContext(c *gin.Context) *Group {
if c == nil {
return nil
}
raw, exists := c.Get("api_key")
if !exists {
return nil
}
apiKey, ok := raw.(*APIKey)
if !ok || apiKey == nil {
return nil
}
return apiKey.Group
}
func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome {
var out claudeMaxCacheBillingOutcome
if usage == nil {
return out
}
rewriteCtx := claudeMaxResponseRewriteContextFromContext(ctx)
return applyClaudeMaxCacheBillingPolicyToUsage(usage, rewriteCtx.Parsed, rewriteCtx.Group, model, accountID)
}
func applyClaudeMaxSimulationToUsageJSONMap(ctx context.Context, usageObj map[string]any, model string, accountID int64) claudeMaxCacheBillingOutcome {
var out claudeMaxCacheBillingOutcome
if usageObj == nil {
return out
}
usage := claudeUsageFromJSONMap(usageObj)
out = applyClaudeMaxSimulationToUsage(ctx, &usage, model, accountID)
if out.Simulated {
rewriteClaudeUsageJSONMap(usageObj, usage)
}
return out
}
func rewriteClaudeUsageJSONBytes(body []byte, usage ClaudeUsage) []byte {
updated := body
var err error
updated, err = sjson.SetBytes(updated, "usage.input_tokens", usage.InputTokens)
if err != nil {
return body
}
updated, err = sjson.SetBytes(updated, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens)
if err != nil {
return body
}
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation5mTokens)
if err != nil {
return body
}
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation1hTokens)
if err != nil {
return body
}
return updated
}
func claudeUsageFromJSONMap(usageObj map[string]any) ClaudeUsage {
var usage ClaudeUsage
if usageObj == nil {
return usage
}
usage.InputTokens = usageIntFromAny(usageObj["input_tokens"])
usage.OutputTokens = usageIntFromAny(usageObj["output_tokens"])
usage.CacheCreationInputTokens = usageIntFromAny(usageObj["cache_creation_input_tokens"])
usage.CacheReadInputTokens = usageIntFromAny(usageObj["cache_read_input_tokens"])
if ccObj, ok := usageObj["cache_creation"].(map[string]any); ok {
usage.CacheCreation5mTokens = usageIntFromAny(ccObj["ephemeral_5m_input_tokens"])
usage.CacheCreation1hTokens = usageIntFromAny(ccObj["ephemeral_1h_input_tokens"])
}
return usage
}
func rewriteClaudeUsageJSONMap(usageObj map[string]any, usage ClaudeUsage) {
if usageObj == nil {
return
}
usageObj["input_tokens"] = usage.InputTokens
usageObj["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
ccObj, _ := usageObj["cache_creation"].(map[string]any)
if ccObj == nil {
ccObj = make(map[string]any, 2)
usageObj["cache_creation"] = ccObj
}
ccObj["ephemeral_5m_input_tokens"] = usage.CacheCreation5mTokens
ccObj["ephemeral_1h_input_tokens"] = usage.CacheCreation1hTokens
}
func usageIntFromAny(v any) int {
switch value := v.(type) {
case int:
return value
case int64:
return int(value)
case float64:
return int(value)
case json.Number:
if n, err := value.Int64(); err == nil {
return int(n)
}
}
return 0
}

View File

@@ -32,7 +32,7 @@ func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayServic
} }
} }
func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsAndSkipsTTLOverride(t *testing.T) { func TestRecordUsage_SimulateClaudeMaxEnabled_DoesNotProjectAndSkipsTTLOverride(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true} repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo) svc := newGatewayServiceForRecordUsageTest(repo)
@@ -92,12 +92,11 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsAndSkipsTTLOverride(t *tes
require.NotNil(t, repo.last) require.NotNil(t, repo.last)
log := repo.last log := repo.last
total := log.InputTokens + log.CacheCreation5mTokens + log.CacheCreation1hTokens require.Equal(t, 160, log.InputTokens)
require.Equal(t, 160, total, "token 总量应保持不变") require.Equal(t, 0, log.CacheCreationTokens)
require.Greater(t, log.CacheCreation1hTokens, 0, "应映射为 1h cache creation") require.Equal(t, 0, log.CacheCreation5mTokens)
require.Equal(t, 0, log.CacheCreation5mTokens, "模拟成功后不应再被 TTL override 改写为 5m") require.Equal(t, 0, log.CacheCreation1hTokens)
require.Equal(t, log.CacheCreation1hTokens, log.CacheCreationTokens, "聚合 cache_creation_tokens 应与 1h 一致") require.False(t, log.CacheTTLOverridden, "simulate outcome should skip account ttl override")
require.False(t, log.CacheTTLOverridden, "模拟成功时应跳过 TTL override 标记")
} }
func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) { func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) {
@@ -144,12 +143,12 @@ func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T)
log := repo.last log := repo.last
require.Equal(t, 120, log.CacheCreationTokens) require.Equal(t, 120, log.CacheCreationTokens)
require.Equal(t, 120, log.CacheCreation5mTokens, "关闭模拟时应执行 TTL override 到 5m") require.Equal(t, 120, log.CacheCreation5mTokens)
require.Equal(t, 0, log.CacheCreation1hTokens) require.Equal(t, 0, log.CacheCreation1hTokens)
require.True(t, log.CacheTTLOverridden, "TTL override 生效时应打标") require.True(t, log.CacheTTLOverridden)
} }
func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationForce1H(t *testing.T) { func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true} repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo) svc := newGatewayServiceForRecordUsageTest(repo)
@@ -192,9 +191,9 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationForce1H(t *te
require.NotNil(t, repo.last) require.NotNil(t, repo.last)
log := repo.last log := repo.last
require.Equal(t, 20, log.InputTokens, "existing cache creation should not project input tokens") require.Equal(t, 20, log.InputTokens)
require.Equal(t, 0, log.CacheCreation5mTokens, "existing cache creation should be forced to 1h") require.Equal(t, 120, log.CacheCreation5mTokens)
require.Equal(t, 120, log.CacheCreation1hTokens) require.Equal(t, 0, log.CacheCreation1hTokens)
require.Equal(t, 120, log.CacheCreationTokens) require.Equal(t, 120, log.CacheCreationTokens)
require.True(t, log.CacheTTLOverridden, "force-to-1h should mark cache ttl overridden") require.True(t, log.CacheTTLOverridden, "existing cache_creation should remain under normal account ttl flow")
} }

View File

@@ -0,0 +1,170 @@
package service
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &GatewayService{
cfg: &config.Config{},
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 11,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
}
group := &Group{
ID: 99,
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: true,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "new user question",
},
},
},
},
}
upstreamBody := []byte(`{"id":"msg_1","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: ioNopCloserBytes(upstreamBody),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
c.Set("api_key", &APIKey{Group: group})
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
require.NoError(t, err)
require.NotNil(t, usage)
var rendered struct {
Usage ClaudeUsage `json:"usage"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rendered))
rendered.Usage.CacheCreation5mTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_5m_input_tokens").Int())
rendered.Usage.CacheCreation1hTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_1h_input_tokens").Int())
require.Equal(t, rendered.Usage.InputTokens, usage.InputTokens)
require.Equal(t, rendered.Usage.OutputTokens, usage.OutputTokens)
require.Equal(t, rendered.Usage.CacheCreationInputTokens, usage.CacheCreationInputTokens)
require.Equal(t, rendered.Usage.CacheCreation5mTokens, usage.CacheCreation5mTokens)
require.Equal(t, rendered.Usage.CacheCreation1hTokens, usage.CacheCreation1hTokens)
require.Equal(t, rendered.Usage.CacheReadInputTokens, usage.CacheReadInputTokens)
require.Greater(t, usage.CacheCreation1hTokens, 0)
require.Equal(t, 0, usage.CacheCreation5mTokens)
require.Less(t, usage.InputTokens, 120)
}
func TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &GatewayService{
cfg: &config.Config{},
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 12,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
}
group := &Group{
ID: 100,
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: false,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "new user question",
},
},
},
},
}
upstreamBody := []byte(`{"id":"msg_2","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: ioNopCloserBytes(upstreamBody),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
c.Set("api_key", &APIKey{Group: group})
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, 120, usage.InputTokens)
require.Equal(t, 0, usage.CacheCreationInputTokens)
require.Equal(t, 0, usage.CacheCreation5mTokens)
require.Equal(t, 0, usage.CacheCreation1hTokens)
}
func ioNopCloserBytes(b []byte) *readCloserFromBytes {
return &readCloserFromBytes{Reader: bytes.NewReader(b)}
}
type readCloserFromBytes struct {
*bytes.Reader
}
func (r *readCloserFromBytes) Close() error {
return nil
}

View File

@@ -3709,6 +3709,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
// 处理正常响应 // 处理正常响应
ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed)
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
@@ -5105,6 +5106,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
skipAccountTTLOverride := false
pendingEventLines := make([]string, 0, 4) pendingEventLines := make([]string, 0, 4)
@@ -5164,17 +5166,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if msg, ok := event["message"].(map[string]any); ok { if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok {
reconcileCachedTokens(u) reconcileCachedTokens(u)
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
skipAccountTTLOverride = true
}
} }
} }
} }
if eventType == "message_delta" { if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok { if u, ok := event["usage"].(map[string]any); ok {
reconcileCachedTokens(u) reconcileCachedTokens(u)
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
skipAccountTTLOverride = true
}
} }
} }
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() { if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride {
overrideTarget := account.GetCacheTTLOverrideTarget() overrideTarget := account.GetCacheTTLOverrideTarget()
if eventType == "message_start" { if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok { if msg, ok := event["message"].(map[string]any); ok {
@@ -5465,8 +5475,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
} }
} }
claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
body = rewriteClaudeUsageJSONBytes(body, response.Usage)
}
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() { if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated && !claudeMaxOutcome.ForcedCache1H {
overrideTarget := account.GetCacheTTLOverrideTarget() overrideTarget := account.GetCacheTTLOverrideTarget()
if applyCacheTTLOverride(&response.Usage, overrideTarget) { if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象 // 同步更新 body JSON 中的嵌套 cache_creation 对象
@@ -5608,9 +5623,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
result.Usage.InputTokens = 0 result.Usage.InputTokens = 0
} }
// Claude Max cache billing policy (group-level): force existing cache creation to 1h, // Claude Max cache billing policy (group-level): RecordUsage only checks outcome.
// otherwise simulate projection only when request carries cache signals. var apiKeyGroup *Group
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicy(input) if apiKey != nil {
apiKeyGroup = apiKey.Group
}
claudeMaxOutcome := detectClaudeMaxCacheBillingOutcomeForUsage(result.Usage, input.ParsedRequest, apiKeyGroup, result.Model)
simulatedClaudeMax := claudeMaxOutcome.Simulated simulatedClaudeMax := claudeMaxOutcome.Simulated
forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H