From 61ef73cb12dbbd997e6fc7ef8c50d8fa130476b4 Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 27 Feb 2026 16:14:07 +0800 Subject: [PATCH] refactor: isolate claude max response usage simulation by group toggle --- .../claude_max_cache_billing_policy.go | 32 +--- .../gateway_claude_max_response_helpers.go | 147 +++++++++++++++ .../gateway_record_usage_claude_max_test.go | 27 ++- .../gateway_response_usage_sync_test.go | 170 ++++++++++++++++++ backend/internal/service/gateway_service.go | 28 ++- 5 files changed, 356 insertions(+), 48 deletions(-) create mode 100644 backend/internal/service/gateway_claude_max_response_helpers.go create mode 100644 backend/internal/service/gateway_response_usage_sync_test.go diff --git a/backend/internal/service/claude_max_cache_billing_policy.go b/backend/internal/service/claude_max_cache_billing_policy.go index 021d968c..398c9ec8 100644 --- a/backend/internal/service/claude_max_cache_billing_policy.go +++ b/backend/internal/service/claude_max_cache_billing_policy.go @@ -31,19 +31,7 @@ func applyClaudeMaxCacheBillingPolicy(input *RecordUsageInput) claudeMaxCacheBil } if hasCacheCreationTokens(*usage) { - before5m := usage.CacheCreation5mTokens - before1h := usage.CacheCreation1hTokens - out.ForcedCache1H = safelyForceCacheCreationTo1H(usage) - if out.ForcedCache1H { - logger.LegacyPrintf("service.gateway", "force_claude_max_cache_1h: model=%s account=%d cache_creation_5m:%d->%d cache_creation_1h:%d->%d", - result.Model, - accountID, - before5m, - usage.CacheCreation5mTokens, - before1h, - usage.CacheCreation1hTokens, - ) - } + // Upstream already returned cache creation usage; keep original usage. return out } @@ -72,7 +60,7 @@ func detectClaudeMaxCacheBillingOutcomeForUsage(usage ClaudeUsage, parsed *Parse return out } if hasCacheCreationTokens(usage) { - out.ForcedCache1H = true + // Upstream already returned cache creation usage; keep original usage. return out } if shouldSimulateClaudeMaxUsageForUsage(usage, parsed) { @@ -93,21 +81,7 @@ func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedR } if hasCacheCreationTokens(*usage) { - before5m := usage.CacheCreation5mTokens - before1h := usage.CacheCreation1hTokens - changed := safelyForceCacheCreationTo1H(usage) - // Even when value is already 1h, still mark forced to skip account TTL override. - out.ForcedCache1H = true - if changed { - logger.LegacyPrintf("service.gateway", "force_claude_max_cache_1h: model=%s account=%d cache_creation_5m:%d->%d cache_creation_1h:%d->%d", - resolvedModel, - accountID, - before5m, - usage.CacheCreation5mTokens, - before1h, - usage.CacheCreation1hTokens, - ) - } + // Upstream already returned cache creation usage; keep original usage. return out } diff --git a/backend/internal/service/gateway_claude_max_response_helpers.go b/backend/internal/service/gateway_claude_max_response_helpers.go new file mode 100644 index 00000000..b4c7e819 --- /dev/null +++ b/backend/internal/service/gateway_claude_max_response_helpers.go @@ -0,0 +1,147 @@ +package service + +import ( + "context" + "encoding/json" + + "github.com/gin-gonic/gin" + "github.com/tidwall/sjson" +) + +type claudeMaxResponseRewriteContext struct { + Parsed *ParsedRequest + Group *Group +} + +type claudeMaxResponseRewriteContextKeyType struct{} + +var claudeMaxResponseRewriteContextKey = claudeMaxResponseRewriteContextKeyType{} + +func withClaudeMaxResponseRewriteContext(ctx context.Context, c *gin.Context, parsed *ParsedRequest) context.Context { + if ctx == nil { + ctx = context.Background() + } + value := claudeMaxResponseRewriteContext{ + Parsed: parsed, + Group: claudeMaxGroupFromGinContext(c), + } + return context.WithValue(ctx, claudeMaxResponseRewriteContextKey, value) +} + +func claudeMaxResponseRewriteContextFromContext(ctx context.Context) claudeMaxResponseRewriteContext { + if ctx == nil { + return claudeMaxResponseRewriteContext{} + } + value, _ := ctx.Value(claudeMaxResponseRewriteContextKey).(claudeMaxResponseRewriteContext) + return value +} + +func claudeMaxGroupFromGinContext(c *gin.Context) *Group { + if c == nil { + return nil + } + raw, exists := c.Get("api_key") + if !exists { + return nil + } + apiKey, ok := raw.(*APIKey) + if !ok || apiKey == nil { + return nil + } + return apiKey.Group +} + +func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome { + var out claudeMaxCacheBillingOutcome + if usage == nil { + return out + } + rewriteCtx := claudeMaxResponseRewriteContextFromContext(ctx) + return applyClaudeMaxCacheBillingPolicyToUsage(usage, rewriteCtx.Parsed, rewriteCtx.Group, model, accountID) +} + +func applyClaudeMaxSimulationToUsageJSONMap(ctx context.Context, usageObj map[string]any, model string, accountID int64) claudeMaxCacheBillingOutcome { + var out claudeMaxCacheBillingOutcome + if usageObj == nil { + return out + } + usage := claudeUsageFromJSONMap(usageObj) + out = applyClaudeMaxSimulationToUsage(ctx, &usage, model, accountID) + if out.Simulated { + rewriteClaudeUsageJSONMap(usageObj, usage) + } + return out +} + +func rewriteClaudeUsageJSONBytes(body []byte, usage ClaudeUsage) []byte { + updated := body + var err error + + updated, err = sjson.SetBytes(updated, "usage.input_tokens", usage.InputTokens) + if err != nil { + return body + } + updated, err = sjson.SetBytes(updated, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens) + if err != nil { + return body + } + updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation5mTokens) + if err != nil { + return body + } + updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation1hTokens) + if err != nil { + return body + } + return updated +} + +func claudeUsageFromJSONMap(usageObj map[string]any) ClaudeUsage { + var usage ClaudeUsage + if usageObj == nil { + return usage + } + + usage.InputTokens = usageIntFromAny(usageObj["input_tokens"]) + usage.OutputTokens = usageIntFromAny(usageObj["output_tokens"]) + usage.CacheCreationInputTokens = usageIntFromAny(usageObj["cache_creation_input_tokens"]) + usage.CacheReadInputTokens = usageIntFromAny(usageObj["cache_read_input_tokens"]) + + if ccObj, ok := usageObj["cache_creation"].(map[string]any); ok { + usage.CacheCreation5mTokens = usageIntFromAny(ccObj["ephemeral_5m_input_tokens"]) + usage.CacheCreation1hTokens = usageIntFromAny(ccObj["ephemeral_1h_input_tokens"]) + } + return usage +} + +func rewriteClaudeUsageJSONMap(usageObj map[string]any, usage ClaudeUsage) { + if usageObj == nil { + return + } + usageObj["input_tokens"] = usage.InputTokens + usageObj["cache_creation_input_tokens"] = usage.CacheCreationInputTokens + + ccObj, _ := usageObj["cache_creation"].(map[string]any) + if ccObj == nil { + ccObj = make(map[string]any, 2) + usageObj["cache_creation"] = ccObj + } + ccObj["ephemeral_5m_input_tokens"] = usage.CacheCreation5mTokens + ccObj["ephemeral_1h_input_tokens"] = usage.CacheCreation1hTokens +} + +func usageIntFromAny(v any) int { + switch value := v.(type) { + case int: + return value + case int64: + return int(value) + case float64: + return int(value) + case json.Number: + if n, err := value.Int64(); err == nil { + return int(n) + } + } + return 0 +} diff --git a/backend/internal/service/gateway_record_usage_claude_max_test.go b/backend/internal/service/gateway_record_usage_claude_max_test.go index 445519f8..7bee1b0f 100644 --- a/backend/internal/service/gateway_record_usage_claude_max_test.go +++ b/backend/internal/service/gateway_record_usage_claude_max_test.go @@ -32,7 +32,7 @@ func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayServic } } -func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsAndSkipsTTLOverride(t *testing.T) { +func TestRecordUsage_SimulateClaudeMaxEnabled_DoesNotProjectAndSkipsTTLOverride(t *testing.T) { repo := &usageLogRepoRecordUsageStub{inserted: true} svc := newGatewayServiceForRecordUsageTest(repo) @@ -92,12 +92,11 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsAndSkipsTTLOverride(t *tes require.NotNil(t, repo.last) log := repo.last - total := log.InputTokens + log.CacheCreation5mTokens + log.CacheCreation1hTokens - require.Equal(t, 160, total, "token 总量应保持不变") - require.Greater(t, log.CacheCreation1hTokens, 0, "应映射为 1h cache creation") - require.Equal(t, 0, log.CacheCreation5mTokens, "模拟成功后不应再被 TTL override 改写为 5m") - require.Equal(t, log.CacheCreation1hTokens, log.CacheCreationTokens, "聚合 cache_creation_tokens 应与 1h 一致") - require.False(t, log.CacheTTLOverridden, "模拟成功时应跳过 TTL override 标记") + require.Equal(t, 160, log.InputTokens) + require.Equal(t, 0, log.CacheCreationTokens) + require.Equal(t, 0, log.CacheCreation5mTokens) + require.Equal(t, 0, log.CacheCreation1hTokens) + require.False(t, log.CacheTTLOverridden, "simulate outcome should skip account ttl override") } func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) { @@ -144,12 +143,12 @@ func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) log := repo.last require.Equal(t, 120, log.CacheCreationTokens) - require.Equal(t, 120, log.CacheCreation5mTokens, "关闭模拟时应执行 TTL override 到 5m") + require.Equal(t, 120, log.CacheCreation5mTokens) require.Equal(t, 0, log.CacheCreation1hTokens) - require.True(t, log.CacheTTLOverridden, "TTL override 生效时应打标") + require.True(t, log.CacheTTLOverridden) } -func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationForce1H(t *testing.T) { +func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation(t *testing.T) { repo := &usageLogRepoRecordUsageStub{inserted: true} svc := newGatewayServiceForRecordUsageTest(repo) @@ -192,9 +191,9 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationForce1H(t *te require.NotNil(t, repo.last) log := repo.last - require.Equal(t, 20, log.InputTokens, "existing cache creation should not project input tokens") - require.Equal(t, 0, log.CacheCreation5mTokens, "existing cache creation should be forced to 1h") - require.Equal(t, 120, log.CacheCreation1hTokens) + require.Equal(t, 20, log.InputTokens) + require.Equal(t, 120, log.CacheCreation5mTokens) + require.Equal(t, 0, log.CacheCreation1hTokens) require.Equal(t, 120, log.CacheCreationTokens) - require.True(t, log.CacheTTLOverridden, "force-to-1h should mark cache ttl overridden") + require.True(t, log.CacheTTLOverridden, "existing cache_creation should remain under normal account ttl flow") } diff --git a/backend/internal/service/gateway_response_usage_sync_test.go b/backend/internal/service/gateway_response_usage_sync_test.go new file mode 100644 index 00000000..445ee8ad --- /dev/null +++ b/backend/internal/service/gateway_response_usage_sync_test.go @@ -0,0 +1,170 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &GatewayService{ + cfg: &config.Config{}, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 11, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "5m", + }, + } + group := &Group{ + ID: 99, + Platform: PlatformAnthropic, + SimulateClaudeMaxEnabled: true, + } + parsed := &ParsedRequest{ + Model: "claude-sonnet-4", + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "long cached context", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "new user question", + }, + }, + }, + }, + } + + upstreamBody := []byte(`{"id":"msg_1","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: ioNopCloserBytes(upstreamBody), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil)) + c.Set("api_key", &APIKey{Group: group}) + requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed) + + usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4") + require.NoError(t, err) + require.NotNil(t, usage) + + var rendered struct { + Usage ClaudeUsage `json:"usage"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rendered)) + rendered.Usage.CacheCreation5mTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_5m_input_tokens").Int()) + rendered.Usage.CacheCreation1hTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_1h_input_tokens").Int()) + + require.Equal(t, rendered.Usage.InputTokens, usage.InputTokens) + require.Equal(t, rendered.Usage.OutputTokens, usage.OutputTokens) + require.Equal(t, rendered.Usage.CacheCreationInputTokens, usage.CacheCreationInputTokens) + require.Equal(t, rendered.Usage.CacheCreation5mTokens, usage.CacheCreation5mTokens) + require.Equal(t, rendered.Usage.CacheCreation1hTokens, usage.CacheCreation1hTokens) + require.Equal(t, rendered.Usage.CacheReadInputTokens, usage.CacheReadInputTokens) + + require.Greater(t, usage.CacheCreation1hTokens, 0) + require.Equal(t, 0, usage.CacheCreation5mTokens) + require.Less(t, usage.InputTokens, 120) +} + +func TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &GatewayService{ + cfg: &config.Config{}, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 12, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "5m", + }, + } + group := &Group{ + ID: 100, + Platform: PlatformAnthropic, + SimulateClaudeMaxEnabled: false, + } + parsed := &ParsedRequest{ + Model: "claude-sonnet-4", + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "text", + "text": "long cached context", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + map[string]any{ + "type": "text", + "text": "new user question", + }, + }, + }, + }, + } + + upstreamBody := []byte(`{"id":"msg_2","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: ioNopCloserBytes(upstreamBody), + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil)) + c.Set("api_key", &APIKey{Group: group}) + requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed) + + usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4") + require.NoError(t, err) + require.NotNil(t, usage) + + require.Equal(t, 120, usage.InputTokens) + require.Equal(t, 0, usage.CacheCreationInputTokens) + require.Equal(t, 0, usage.CacheCreation5mTokens) + require.Equal(t, 0, usage.CacheCreation1hTokens) +} + +func ioNopCloserBytes(b []byte) *readCloserFromBytes { + return &readCloserFromBytes{Reader: bytes.NewReader(b)} +} + +type readCloserFromBytes struct { + *bytes.Reader +} + +func (r *readCloserFromBytes) Close() error { + return nil +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 53b1fd28..e025c6d9 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3709,6 +3709,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 + ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed) var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool @@ -5105,6 +5106,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + skipAccountTTLOverride := false pendingEventLines := make([]string, 0, 4) @@ -5164,17 +5166,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { reconcileCachedTokens(u) + claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + skipAccountTTLOverride = true + } } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { reconcileCachedTokens(u) + claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + skipAccountTTLOverride = true + } } } // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride { overrideTarget := account.GetCacheTTLOverrideTarget() if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { @@ -5465,8 +5475,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + body = rewriteClaudeUsageJSONBytes(body, response.Usage) + } + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated && !claudeMaxOutcome.ForcedCache1H { overrideTarget := account.GetCacheTTLOverrideTarget() if applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 @@ -5608,9 +5623,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu result.Usage.InputTokens = 0 } - // Claude Max cache billing policy (group-level): force existing cache creation to 1h, - // otherwise simulate projection only when request carries cache signals. - claudeMaxOutcome := applyClaudeMaxCacheBillingPolicy(input) + // Claude Max cache billing policy (group-level): RecordUsage only checks outcome. + var apiKeyGroup *Group + if apiKey != nil { + apiKeyGroup = apiKey.Group + } + claudeMaxOutcome := detectClaudeMaxCacheBillingOutcomeForUsage(result.Usage, input.ParsedRequest, apiKeyGroup, result.Model) simulatedClaudeMax := claudeMaxOutcome.Simulated forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H