diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 103bd086..127715dd 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -545,6 +545,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) // 转发请求 - 根据账号平台分流 + c.Set("parsed_request", parsedReq) var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index 677435ad..54f7e282 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -18,6 +18,9 @@ const ( BlockTypeFunction ) +// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events. +type UsageMapHook func(usageMap map[string]any) + // StreamingProcessor 流式响应处理器 type StreamingProcessor struct { blockType BlockType @@ -30,6 +33,7 @@ type StreamingProcessor struct { originalModel string webSearchQueries []string groundingChunks []GeminiGroundingChunk + usageMapHook UsageMapHook // 累计 usage inputTokens int @@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor { } } +// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted. +func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) { + p.usageMapHook = fn +} + +func usageToMap(u ClaudeUsage) map[string]any { + m := map[string]any{ + "input_tokens": u.InputTokens, + "output_tokens": u.OutputTokens, + } + if u.CacheCreationInputTokens > 0 { + m["cache_creation_input_tokens"] = u.CacheCreationInputTokens + } + if u.CacheReadInputTokens > 0 { + m["cache_read_input_tokens"] = u.CacheReadInputTokens + } + return m +} + // ProcessLine 处理 SSE 行,返回 Claude SSE 事件 func (p *StreamingProcessor) ProcessLine(line string) []byte { line = strings.TrimSpace(line) @@ -158,6 +181,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte responseID = "msg_" + generateRandomID() } + var usageValue any = usage + if p.usageMapHook != nil { + usageMap := usageToMap(usage) + p.usageMapHook(usageMap) + usageValue = usageMap + } + message := map[string]any{ "id": responseID, "type": "message", @@ -166,7 +196,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte "model": p.originalModel, "stop_reason": nil, "stop_sequence": nil, - "usage": usage, + "usage": usageValue, } event := map[string]any{ @@ -477,13 +507,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { CacheReadInputTokens: p.cacheReadTokens, } + var usageValue any = usage + if p.usageMapHook != nil { + usageMap := usageToMap(usage) + p.usageMapHook(usageMap) + usageValue = usageMap + } + deltaEvent := map[string]any{ "type": "message_delta", "delta": map[string]any{ "stop_reason": stopReason, "stop_sequence": nil, }, - "usage": usage, + "usage": usageValue, } _, _ = result.Write(p.formatSSE("message_delta", deltaEvent)) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2bd6195a..4922de3c 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1600,7 +1600,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, var clientDisconnect bool if claudeReq.Stream { // 客户端要求流式,直接透传转换 - streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) + streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID) if err != nil { logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) return nil, err @@ -1610,7 +1610,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后转换返回 - streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) + streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID) if err != nil { logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) return nil, err @@ -3416,7 +3416,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, // handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回 // 用于处理客户端非流式请求但上游只支持流式的情况 -func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) { +func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) { scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { @@ -3574,6 +3574,9 @@ returnResponse: return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } + // Claude Max cache billing simulation (non-streaming) + claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID) + c.Data(http.StatusOK, "application/json", claudeResp) // 转换为 service.ClaudeUsage @@ -3588,7 +3591,7 @@ returnResponse: } // handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换) -func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) { +func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") @@ -3601,6 +3604,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } processor := antigravity.NewStreamingProcessor(originalModel) + setupClaudeMaxStreamingHook(c, processor, originalModel, accountID) + var firstTokenMs *int // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM scanner := bufio.NewScanner(resp.Body) diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 84b65adc..cbecfee5 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -710,7 +710,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { fmt.Fprintln(pw, "") }() - result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0) _ = pr.Close() require.NoError(t, err) @@ -787,7 +787,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) { fmt.Fprintln(pw, "") }() - result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro") + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0) _ = pr.Close() require.NoError(t, err) @@ -990,7 +990,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) { fmt.Fprintln(pw, "") }() - result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0) _ = pr.Close() require.NoError(t, err) @@ -1014,7 +1014,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) { resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} - result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0) require.NoError(t, err) require.NotNil(t, result) diff --git a/backend/internal/service/claude_max_cache_billing_policy.go b/backend/internal/service/claude_max_cache_billing_policy.go index 398c9ec8..64696d4d 100644 --- a/backend/internal/service/claude_max_cache_billing_policy.go +++ b/backend/internal/service/claude_max_cache_billing_policy.go @@ -10,46 +10,7 @@ import ( ) type claudeMaxCacheBillingOutcome struct { - Simulated bool - ForcedCache1H bool -} - -func applyClaudeMaxCacheBillingPolicy(input *RecordUsageInput) claudeMaxCacheBillingOutcome { - var out claudeMaxCacheBillingOutcome - if !shouldApplyClaudeMaxBillingRules(input) { - return out - } - - if input == nil || input.Result == nil { - return out - } - result := input.Result - usage := &result.Usage - accountID := int64(0) - if input.Account != nil { - accountID = input.Account.ID - } - - if hasCacheCreationTokens(*usage) { - // Upstream already returned cache creation usage; keep original usage. - return out - } - - if !shouldSimulateClaudeMaxUsage(input) { - return out - } - beforeInputTokens := usage.InputTokens - out.Simulated = safelyApplyClaudeMaxUsageSimulation(result, input.ParsedRequest) - if out.Simulated { - logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d", - result.Model, - accountID, - beforeInputTokens, - usage.InputTokens, - usage.CacheCreation1hTokens, - ) - } - return out + Simulated bool } // detectClaudeMaxCacheBillingOutcomeForUsage only returns whether Claude Max policy @@ -150,55 +111,18 @@ func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool { } func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool { - if !hasClaudeCacheSignals(parsed) { - return false - } if usage.InputTokens <= 0 { return false } if hasCacheCreationTokens(usage) { return false } + if !hasClaudeCacheSignals(parsed) { + return false + } return true } -func forceCacheCreationTo1H(usage *ClaudeUsage) bool { - if usage == nil || !hasCacheCreationTokens(*usage) { - return false - } - - before5m := usage.CacheCreation5mTokens - before1h := usage.CacheCreation1hTokens - beforeAgg := usage.CacheCreationInputTokens - - _ = applyCacheTTLOverride(usage, "1h") - total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens - if total <= 0 { - total = usage.CacheCreationInputTokens - } - if total <= 0 { - return false - } - - usage.CacheCreation5mTokens = 0 - usage.CacheCreation1hTokens = total - usage.CacheCreationInputTokens = total - - return before5m != usage.CacheCreation5mTokens || - before1h != usage.CacheCreation1hTokens || - beforeAgg != usage.CacheCreationInputTokens -} - -func safelyApplyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) (changed bool) { - defer func() { - if r := recover(); r != nil { - logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r) - changed = false - } - }() - return applyClaudeMaxUsageSimulation(result, parsed) -} - func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) { defer func() { if r := recover(); r != nil { @@ -209,23 +133,6 @@ func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) return projectUsageToClaudeMax1H(usage, parsed) } -func safelyForceCacheCreationTo1H(usage *ClaudeUsage) (changed bool) { - defer func() { - if r := recover(); r != nil { - logger.LegacyPrintf("service.gateway", "force_cache_creation_1h skipped: panic=%v", r) - changed = false - } - }() - return forceCacheCreationTo1H(usage) -} - -func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool { - if result == nil { - return false - } - return projectUsageToClaudeMax1H(&result.Usage, parsed) -} - func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool { if usage == nil { return false diff --git a/backend/internal/service/gateway_claude_max_response_helpers.go b/backend/internal/service/gateway_claude_max_response_helpers.go index b4c7e819..a5f5f3d2 100644 --- a/backend/internal/service/gateway_claude_max_response_helpers.go +++ b/backend/internal/service/gateway_claude_max_response_helpers.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/gin-gonic/gin" "github.com/tidwall/sjson" ) @@ -51,6 +52,18 @@ func claudeMaxGroupFromGinContext(c *gin.Context) *Group { return apiKey.Group } +func parsedRequestFromGinContext(c *gin.Context) *ParsedRequest { + if c == nil { + return nil + } + raw, exists := c.Get("parsed_request") + if !exists { + return nil + } + parsed, _ := raw.(*ParsedRequest) + return parsed +} + func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome { var out claudeMaxCacheBillingOutcome if usage == nil { @@ -145,3 +158,39 @@ func usageIntFromAny(v any) int { } return 0 } + +// setupClaudeMaxStreamingHook 为 Antigravity 流式路径设置 SSE usage 改写 hook。 +func setupClaudeMaxStreamingHook(c *gin.Context, processor *antigravity.StreamingProcessor, originalModel string, accountID int64) { + group := claudeMaxGroupFromGinContext(c) + parsed := parsedRequestFromGinContext(c) + if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) { + return + } + processor.SetUsageMapHook(func(usageMap map[string]any) { + svcUsage := claudeUsageFromJSONMap(usageMap) + outcome := applyClaudeMaxCacheBillingPolicyToUsage(&svcUsage, parsed, group, originalModel, accountID) + if outcome.Simulated { + rewriteClaudeUsageJSONMap(usageMap, svcUsage) + } + }) +} + +// applyClaudeMaxNonStreamingRewrite 为 Antigravity 非流式路径改写响应体中的 usage。 +func applyClaudeMaxNonStreamingRewrite(c *gin.Context, claudeResp []byte, agUsage *antigravity.ClaudeUsage, originalModel string, accountID int64) []byte { + group := claudeMaxGroupFromGinContext(c) + parsed := parsedRequestFromGinContext(c) + if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) { + return claudeResp + } + svcUsage := &ClaudeUsage{ + InputTokens: agUsage.InputTokens, + OutputTokens: agUsage.OutputTokens, + CacheCreationInputTokens: agUsage.CacheCreationInputTokens, + CacheReadInputTokens: agUsage.CacheReadInputTokens, + } + outcome := applyClaudeMaxCacheBillingPolicyToUsage(svcUsage, parsed, group, originalModel, accountID) + if outcome.Simulated { + return rewriteClaudeUsageJSONBytes(claudeResp, *svcUsage) + } + return claudeResp +} diff --git a/backend/internal/service/gateway_record_usage_claude_max_test.go b/backend/internal/service/gateway_record_usage_claude_max_test.go index 7bee1b0f..2e1b5ae7 100644 --- a/backend/internal/service/gateway_record_usage_claude_max_test.go +++ b/backend/internal/service/gateway_record_usage_claude_max_test.go @@ -32,7 +32,7 @@ func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayServic } } -func TestRecordUsage_SimulateClaudeMaxEnabled_DoesNotProjectAndSkipsTTLOverride(t *testing.T) { +func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride(t *testing.T) { repo := &usageLogRepoRecordUsageStub{inserted: true} svc := newGatewayServiceForRecordUsageTest(repo) @@ -195,5 +195,5 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimul require.Equal(t, 120, log.CacheCreation5mTokens) require.Equal(t, 0, log.CacheCreation1hTokens) require.Equal(t, 120, log.CacheCreationTokens) - require.True(t, log.CacheTTLOverridden, "existing cache_creation should remain under normal account ttl flow") + require.True(t, log.CacheTTLOverridden, "existing cache_creation with SimulateClaudeMax enabled should apply account ttl override") } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e025c6d9..2b47509e 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5481,7 +5481,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated && !claudeMaxOutcome.ForcedCache1H { + if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated { overrideTarget := account.GetCacheTTLOverrideTarget() if applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 @@ -5623,18 +5623,18 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu result.Usage.InputTokens = 0 } - // Claude Max cache billing policy (group-level): RecordUsage only checks outcome. + // Claude Max cache billing policy (group-level): + // - GatewayService 路径: Forward 已改写 usage(含 cache tokens)→ apply 见到 cache tokens 跳过 → simulatedClaudeMax=true(通过第二条件) + // - Antigravity 路径: Forward 中 hook 改写了客户端 SSE,但 ForwardResult.Usage 是原始值 → apply 实际执行模拟 → simulatedClaudeMax=true var apiKeyGroup *Group if apiKey != nil { apiKeyGroup = apiKey.Group } claudeMaxOutcome := detectClaudeMaxCacheBillingOutcomeForUsage(result.Usage, input.ParsedRequest, apiKeyGroup, result.Model) simulatedClaudeMax := claudeMaxOutcome.Simulated - forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 - cacheTTLOverridden := forcedClaudeMax1H - if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax && !forcedClaudeMax1H { + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 }