mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-24 08:34:45 +08:00
refactor: decouple claude max cache simulation from RecordUsage
Extract setupClaudeMaxStreamingHook and applyClaudeMaxNonStreamingRewrite facade functions to helpers file. RecordUsage now uses detect-only (no mutation), client response rewriting handled at Forward layer.
This commit is contained in:
@@ -545,6 +545,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
|
c.Set("parsed_request", parsedReq)
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ const (
|
|||||||
BlockTypeFunction
|
BlockTypeFunction
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
||||||
|
type UsageMapHook func(usageMap map[string]any)
|
||||||
|
|
||||||
// StreamingProcessor 流式响应处理器
|
// StreamingProcessor 流式响应处理器
|
||||||
type StreamingProcessor struct {
|
type StreamingProcessor struct {
|
||||||
blockType BlockType
|
blockType BlockType
|
||||||
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
|
|||||||
originalModel string
|
originalModel string
|
||||||
webSearchQueries []string
|
webSearchQueries []string
|
||||||
groundingChunks []GeminiGroundingChunk
|
groundingChunks []GeminiGroundingChunk
|
||||||
|
usageMapHook UsageMapHook
|
||||||
|
|
||||||
// 累计 usage
|
// 累计 usage
|
||||||
inputTokens int
|
inputTokens int
|
||||||
@@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
|
||||||
|
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
|
||||||
|
p.usageMapHook = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func usageToMap(u ClaudeUsage) map[string]any {
|
||||||
|
m := map[string]any{
|
||||||
|
"input_tokens": u.InputTokens,
|
||||||
|
"output_tokens": u.OutputTokens,
|
||||||
|
}
|
||||||
|
if u.CacheCreationInputTokens > 0 {
|
||||||
|
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
||||||
|
}
|
||||||
|
if u.CacheReadInputTokens > 0 {
|
||||||
|
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
@@ -158,6 +181,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
responseID = "msg_" + generateRandomID()
|
responseID = "msg_" + generateRandomID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var usageValue any = usage
|
||||||
|
if p.usageMapHook != nil {
|
||||||
|
usageMap := usageToMap(usage)
|
||||||
|
p.usageMapHook(usageMap)
|
||||||
|
usageValue = usageMap
|
||||||
|
}
|
||||||
|
|
||||||
message := map[string]any{
|
message := map[string]any{
|
||||||
"id": responseID,
|
"id": responseID,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
@@ -166,7 +196,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
"model": p.originalModel,
|
"model": p.originalModel,
|
||||||
"stop_reason": nil,
|
"stop_reason": nil,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
"usage": usage,
|
"usage": usageValue,
|
||||||
}
|
}
|
||||||
|
|
||||||
event := map[string]any{
|
event := map[string]any{
|
||||||
@@ -477,13 +507,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
|||||||
CacheReadInputTokens: p.cacheReadTokens,
|
CacheReadInputTokens: p.cacheReadTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var usageValue any = usage
|
||||||
|
if p.usageMapHook != nil {
|
||||||
|
usageMap := usageToMap(usage)
|
||||||
|
p.usageMapHook(usageMap)
|
||||||
|
usageValue = usageMap
|
||||||
|
}
|
||||||
|
|
||||||
deltaEvent := map[string]any{
|
deltaEvent := map[string]any{
|
||||||
"type": "message_delta",
|
"type": "message_delta",
|
||||||
"delta": map[string]any{
|
"delta": map[string]any{
|
||||||
"stop_reason": stopReason,
|
"stop_reason": stopReason,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
},
|
},
|
||||||
"usage": usage,
|
"usage": usageValue,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||||
|
|||||||
@@ -1600,7 +1600,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
var clientDisconnect bool
|
var clientDisconnect bool
|
||||||
if claudeReq.Stream {
|
if claudeReq.Stream {
|
||||||
// 客户端要求流式,直接透传转换
|
// 客户端要求流式,直接透传转换
|
||||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1610,7 +1610,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
clientDisconnect = streamRes.clientDisconnect
|
clientDisconnect = streamRes.clientDisconnect
|
||||||
} else {
|
} else {
|
||||||
// 客户端要求非流式,收集流式响应后转换返回
|
// 客户端要求非流式,收集流式响应后转换返回
|
||||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -3416,7 +3416,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
|
|||||||
|
|
||||||
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
||||||
// 用于处理客户端非流式请求但上游只支持流式的情况
|
// 用于处理客户端非流式请求但上游只支持流式的情况
|
||||||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||||
@@ -3574,6 +3574,9 @@ returnResponse:
|
|||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Claude Max cache billing simulation (non-streaming)
|
||||||
|
claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID)
|
||||||
|
|
||||||
c.Data(http.StatusOK, "application/json", claudeResp)
|
c.Data(http.StatusOK, "application/json", claudeResp)
|
||||||
|
|
||||||
// 转换为 service.ClaudeUsage
|
// 转换为 service.ClaudeUsage
|
||||||
@@ -3588,7 +3591,7 @@ returnResponse:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
||||||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
@@ -3601,6 +3604,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
processor := antigravity.NewStreamingProcessor(originalModel)
|
processor := antigravity.NewStreamingProcessor(originalModel)
|
||||||
|
setupClaudeMaxStreamingHook(c, processor, originalModel, accountID)
|
||||||
|
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
|||||||
@@ -710,7 +710,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -787,7 +787,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -990,7 +990,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1014,7 +1014,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
|||||||
|
|
||||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
|
|||||||
@@ -10,46 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type claudeMaxCacheBillingOutcome struct {
|
type claudeMaxCacheBillingOutcome struct {
|
||||||
Simulated bool
|
Simulated bool
|
||||||
ForcedCache1H bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyClaudeMaxCacheBillingPolicy(input *RecordUsageInput) claudeMaxCacheBillingOutcome {
|
|
||||||
var out claudeMaxCacheBillingOutcome
|
|
||||||
if !shouldApplyClaudeMaxBillingRules(input) {
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
if input == nil || input.Result == nil {
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
result := input.Result
|
|
||||||
usage := &result.Usage
|
|
||||||
accountID := int64(0)
|
|
||||||
if input.Account != nil {
|
|
||||||
accountID = input.Account.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasCacheCreationTokens(*usage) {
|
|
||||||
// Upstream already returned cache creation usage; keep original usage.
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
if !shouldSimulateClaudeMaxUsage(input) {
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
beforeInputTokens := usage.InputTokens
|
|
||||||
out.Simulated = safelyApplyClaudeMaxUsageSimulation(result, input.ParsedRequest)
|
|
||||||
if out.Simulated {
|
|
||||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
|
|
||||||
result.Model,
|
|
||||||
accountID,
|
|
||||||
beforeInputTokens,
|
|
||||||
usage.InputTokens,
|
|
||||||
usage.CacheCreation1hTokens,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectClaudeMaxCacheBillingOutcomeForUsage only returns whether Claude Max policy
|
// detectClaudeMaxCacheBillingOutcomeForUsage only returns whether Claude Max policy
|
||||||
@@ -150,55 +111,18 @@ func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
|
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
|
||||||
if !hasClaudeCacheSignals(parsed) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if usage.InputTokens <= 0 {
|
if usage.InputTokens <= 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if hasCacheCreationTokens(usage) {
|
if hasCacheCreationTokens(usage) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if !hasClaudeCacheSignals(parsed) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func forceCacheCreationTo1H(usage *ClaudeUsage) bool {
|
|
||||||
if usage == nil || !hasCacheCreationTokens(*usage) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
before5m := usage.CacheCreation5mTokens
|
|
||||||
before1h := usage.CacheCreation1hTokens
|
|
||||||
beforeAgg := usage.CacheCreationInputTokens
|
|
||||||
|
|
||||||
_ = applyCacheTTLOverride(usage, "1h")
|
|
||||||
total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
|
|
||||||
if total <= 0 {
|
|
||||||
total = usage.CacheCreationInputTokens
|
|
||||||
}
|
|
||||||
if total <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
usage.CacheCreation5mTokens = 0
|
|
||||||
usage.CacheCreation1hTokens = total
|
|
||||||
usage.CacheCreationInputTokens = total
|
|
||||||
|
|
||||||
return before5m != usage.CacheCreation5mTokens ||
|
|
||||||
before1h != usage.CacheCreation1hTokens ||
|
|
||||||
beforeAgg != usage.CacheCreationInputTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
func safelyApplyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) (changed bool) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
|
|
||||||
changed = false
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return applyClaudeMaxUsageSimulation(result, parsed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
|
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -209,23 +133,6 @@ func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest)
|
|||||||
return projectUsageToClaudeMax1H(usage, parsed)
|
return projectUsageToClaudeMax1H(usage, parsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
func safelyForceCacheCreationTo1H(usage *ClaudeUsage) (changed bool) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "force_cache_creation_1h skipped: panic=%v", r)
|
|
||||||
changed = false
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return forceCacheCreationTo1H(usage)
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyClaudeMaxUsageSimulation(result *ForwardResult, parsed *ParsedRequest) bool {
|
|
||||||
if result == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return projectUsageToClaudeMax1H(&result.Usage, parsed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
|
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
|
||||||
if usage == nil {
|
if usage == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -51,6 +52,18 @@ func claudeMaxGroupFromGinContext(c *gin.Context) *Group {
|
|||||||
return apiKey.Group
|
return apiKey.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parsedRequestFromGinContext(c *gin.Context) *ParsedRequest {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw, exists := c.Get("parsed_request")
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parsed, _ := raw.(*ParsedRequest)
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
||||||
var out claudeMaxCacheBillingOutcome
|
var out claudeMaxCacheBillingOutcome
|
||||||
if usage == nil {
|
if usage == nil {
|
||||||
@@ -145,3 +158,39 @@ func usageIntFromAny(v any) int {
|
|||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupClaudeMaxStreamingHook 为 Antigravity 流式路径设置 SSE usage 改写 hook。
|
||||||
|
func setupClaudeMaxStreamingHook(c *gin.Context, processor *antigravity.StreamingProcessor, originalModel string, accountID int64) {
|
||||||
|
group := claudeMaxGroupFromGinContext(c)
|
||||||
|
parsed := parsedRequestFromGinContext(c)
|
||||||
|
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
processor.SetUsageMapHook(func(usageMap map[string]any) {
|
||||||
|
svcUsage := claudeUsageFromJSONMap(usageMap)
|
||||||
|
outcome := applyClaudeMaxCacheBillingPolicyToUsage(&svcUsage, parsed, group, originalModel, accountID)
|
||||||
|
if outcome.Simulated {
|
||||||
|
rewriteClaudeUsageJSONMap(usageMap, svcUsage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyClaudeMaxNonStreamingRewrite 为 Antigravity 非流式路径改写响应体中的 usage。
|
||||||
|
func applyClaudeMaxNonStreamingRewrite(c *gin.Context, claudeResp []byte, agUsage *antigravity.ClaudeUsage, originalModel string, accountID int64) []byte {
|
||||||
|
group := claudeMaxGroupFromGinContext(c)
|
||||||
|
parsed := parsedRequestFromGinContext(c)
|
||||||
|
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
|
||||||
|
return claudeResp
|
||||||
|
}
|
||||||
|
svcUsage := &ClaudeUsage{
|
||||||
|
InputTokens: agUsage.InputTokens,
|
||||||
|
OutputTokens: agUsage.OutputTokens,
|
||||||
|
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||||||
|
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||||||
|
}
|
||||||
|
outcome := applyClaudeMaxCacheBillingPolicyToUsage(svcUsage, parsed, group, originalModel, accountID)
|
||||||
|
if outcome.Simulated {
|
||||||
|
return rewriteClaudeUsageJSONBytes(claudeResp, *svcUsage)
|
||||||
|
}
|
||||||
|
return claudeResp
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayServic
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRecordUsage_SimulateClaudeMaxEnabled_DoesNotProjectAndSkipsTTLOverride(t *testing.T) {
|
func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride(t *testing.T) {
|
||||||
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
||||||
svc := newGatewayServiceForRecordUsageTest(repo)
|
svc := newGatewayServiceForRecordUsageTest(repo)
|
||||||
|
|
||||||
@@ -195,5 +195,5 @@ func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimul
|
|||||||
require.Equal(t, 120, log.CacheCreation5mTokens)
|
require.Equal(t, 120, log.CacheCreation5mTokens)
|
||||||
require.Equal(t, 0, log.CacheCreation1hTokens)
|
require.Equal(t, 0, log.CacheCreation1hTokens)
|
||||||
require.Equal(t, 120, log.CacheCreationTokens)
|
require.Equal(t, 120, log.CacheCreationTokens)
|
||||||
require.True(t, log.CacheTTLOverridden, "existing cache_creation should remain under normal account ttl flow")
|
require.True(t, log.CacheTTLOverridden, "existing cache_creation with SimulateClaudeMax enabled should apply account ttl override")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5481,7 +5481,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
|
||||||
if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated && !claudeMaxOutcome.ForcedCache1H {
|
if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated {
|
||||||
overrideTarget := account.GetCacheTTLOverrideTarget()
|
overrideTarget := account.GetCacheTTLOverrideTarget()
|
||||||
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
|
||||||
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
// 同步更新 body JSON 中的嵌套 cache_creation 对象
|
||||||
@@ -5623,18 +5623,18 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
result.Usage.InputTokens = 0
|
result.Usage.InputTokens = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claude Max cache billing policy (group-level): RecordUsage only checks outcome.
|
// Claude Max cache billing policy (group-level):
|
||||||
|
// - GatewayService 路径: Forward 已改写 usage(含 cache tokens)→ apply 见到 cache tokens 跳过 → simulatedClaudeMax=true(通过第二条件)
|
||||||
|
// - Antigravity 路径: Forward 中 hook 改写了客户端 SSE,但 ForwardResult.Usage 是原始值 → apply 实际执行模拟 → simulatedClaudeMax=true
|
||||||
var apiKeyGroup *Group
|
var apiKeyGroup *Group
|
||||||
if apiKey != nil {
|
if apiKey != nil {
|
||||||
apiKeyGroup = apiKey.Group
|
apiKeyGroup = apiKey.Group
|
||||||
}
|
}
|
||||||
claudeMaxOutcome := detectClaudeMaxCacheBillingOutcomeForUsage(result.Usage, input.ParsedRequest, apiKeyGroup, result.Model)
|
claudeMaxOutcome := detectClaudeMaxCacheBillingOutcomeForUsage(result.Usage, input.ParsedRequest, apiKeyGroup, result.Model)
|
||||||
simulatedClaudeMax := claudeMaxOutcome.Simulated
|
simulatedClaudeMax := claudeMaxOutcome.Simulated
|
||||||
forcedClaudeMax1H := claudeMaxOutcome.ForcedCache1H
|
|
||||||
|
|
||||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||||
cacheTTLOverridden := forcedClaudeMax1H
|
cacheTTLOverridden := false
|
||||||
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax && !forcedClaudeMax1H {
|
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
|
||||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user