mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-29 19:04:48 +08:00
feat: decouple billing correctness from usage log batching
This commit is contained in:
@@ -50,6 +50,7 @@ const (
|
||||
|
||||
defaultUserGroupRateCacheTTL = 30 * time.Second
|
||||
defaultModelsListCacheTTL = 15 * time.Second
|
||||
postUsageBillingTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -106,6 +107,52 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
|
||||
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
|
||||
}
|
||||
|
||||
func claudeUsageHasAnyTokens(usage *ClaudeUsage) bool {
|
||||
return usage != nil && (usage.InputTokens > 0 ||
|
||||
usage.OutputTokens > 0 ||
|
||||
usage.CacheCreationInputTokens > 0 ||
|
||||
usage.CacheReadInputTokens > 0 ||
|
||||
usage.CacheCreation5mTokens > 0 ||
|
||||
usage.CacheCreation1hTokens > 0)
|
||||
}
|
||||
|
||||
func openAIUsageHasAnyTokens(usage *OpenAIUsage) bool {
|
||||
return usage != nil && (usage.InputTokens > 0 ||
|
||||
usage.OutputTokens > 0 ||
|
||||
usage.CacheCreationInputTokens > 0 ||
|
||||
usage.CacheReadInputTokens > 0)
|
||||
}
|
||||
|
||||
func openAIStreamEventIsTerminal(data string) bool {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if trimmed == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
switch gjson.Get(trimmed, "type").String() {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func anthropicStreamEventIsTerminal(eventName, data string) bool {
|
||||
if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") {
|
||||
return true
|
||||
}
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if trimmed == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
return gjson.Get(trimmed, "type").String() == "message_stop"
|
||||
}
|
||||
|
||||
func cloneStringSlice(src []string) []string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -504,6 +551,7 @@ type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageBillingRepo UsageBillingRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
@@ -537,6 +585,7 @@ func NewGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageBillingRepo UsageBillingRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
@@ -563,6 +612,7 @@ func NewGatewayService(
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
@@ -4049,7 +4099,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4127,7 +4179,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// also downgrade tool_use/tool_result blocks to text.
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseRetryCtx()
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
@@ -4159,7 +4213,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseRetryCtx2()
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr2 == nil {
|
||||
@@ -4226,7 +4282,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
rectifiedBody, applied := RectifyThinkingBudget(body)
|
||||
if applied && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseBudgetRetryCtx()
|
||||
if buildErr == nil {
|
||||
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
@@ -4498,7 +4556,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
var resp *http.Response
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4774,6 +4834,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawTerminalEvent := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -4836,17 +4897,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
|
||||
flusher.Flush()
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
|
||||
}
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err())
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
@@ -4858,11 +4922,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
line := ev.line
|
||||
if data, ok := extractAnthropicSSEDataLine(line); ok {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if anthropicStreamEventIsTerminal("", trimmed) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsagePassthrough(data, usage)
|
||||
} else {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
@@ -4884,8 +4956,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
|
||||
if s.rateLimitService != nil {
|
||||
@@ -6011,6 +6082,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
sawTerminalEvent := false
|
||||
|
||||
pendingEventLines := make([]string, 0, 4)
|
||||
|
||||
@@ -6041,6 +6113,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
if dataLine == "[DONE]" {
|
||||
sawTerminalEvent = true
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
block = "event: " + eventName + "\n"
|
||||
@@ -6107,6 +6180,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
usagePatch := s.extractSSEUsagePatch(event)
|
||||
if anthropicStreamEventIsTerminal(eventName, dataLine) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if !eventChanged {
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
@@ -6140,18 +6216,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// 上游完成,返回结果
|
||||
if !sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
|
||||
}
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
@@ -6209,9 +6289,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
// 客户端已断开,上游也超时了,返回已收集的 usage
|
||||
logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
@@ -6557,15 +6635,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
}
|
||||
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
||||
@@ -6574,6 +6653,14 @@ type APIKeyQuotaUpdater interface {
|
||||
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
}
|
||||
|
||||
// apiKeyAuthCacheInvalidator is an optional capability that applyUsageBilling
// probes for, via type assertion on the APIKeyService value, when the billing
// result reports the API key quota as exhausted.
type apiKeyAuthCacheInvalidator interface {
|
||||
// InvalidateAuthCacheByKey is called with the API key string whose cached
// auth entry should be dropped. It returns nothing, so callers cannot
// observe failures.
InvalidateAuthCacheByKey(ctx context.Context, key string)
|
||||
}
|
||||
|
||||
// usageLogBestEffortWriter is an optional capability that
// writeUsageLogBestEffort probes for, via type assertion on the usage log
// repository. When the repository implements it, CreateBestEffort is tried
// first and the repository's Create is only used as a fallback on error.
type usageLogBestEffortWriter interface {
|
||||
// CreateBestEffort writes the usage log and returns a non-nil error when
// the write could not be accepted.
CreateBestEffort(ctx context.Context, log *UsageLog) error
|
||||
}
|
||||
|
||||
// postUsageBillingParams 统一扣费所需的参数
|
||||
type postUsageBillingParams struct {
|
||||
Cost *CostBreakdown
|
||||
@@ -6581,6 +6668,7 @@ type postUsageBillingParams struct {
|
||||
APIKey *APIKey
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
RequestPayloadHash string
|
||||
IsSubscriptionBill bool
|
||||
AccountRateMultiplier float64
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
@@ -6592,19 +6680,22 @@ type postUsageBillingParams struct {
|
||||
// - API Key 限速用量更新
|
||||
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
||||
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||
billingCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
cost := p.Cost
|
||||
|
||||
// 1. 订阅 / 余额扣费
|
||||
if p.IsSubscriptionBill {
|
||||
if cost.TotalCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
|
||||
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
||||
@@ -6613,31 +6704,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
|
||||
// 2. API Key 配额
|
||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 更新账号最近使用时间
|
||||
finalizePostUsageBilling(p, deps)
|
||||
}
|
||||
|
||||
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
||||
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
if ctx != nil {
|
||||
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||
return "client:" + strings.TrimSpace(clientRequestID)
|
||||
}
|
||||
if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
|
||||
return "local:" + strings.TrimSpace(requestID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
|
||||
if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" {
|
||||
return payloadHash
|
||||
}
|
||||
if ctx != nil {
|
||||
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||
return "client:" + strings.TrimSpace(clientRequestID)
|
||||
}
|
||||
if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
|
||||
return "local:" + strings.TrimSpace(requestID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand {
|
||||
if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := &UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: p.APIKey.ID,
|
||||
UserID: p.User.ID,
|
||||
AccountID: p.Account.ID,
|
||||
AccountType: p.Account.Type,
|
||||
RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash),
|
||||
}
|
||||
if usageLog != nil {
|
||||
cmd.Model = usageLog.Model
|
||||
cmd.BillingType = usageLog.BillingType
|
||||
cmd.InputTokens = usageLog.InputTokens
|
||||
cmd.OutputTokens = usageLog.OutputTokens
|
||||
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
||||
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
||||
cmd.ImageCount = usageLog.ImageCount
|
||||
if usageLog.MediaType != nil {
|
||||
cmd.MediaType = *usageLog.MediaType
|
||||
}
|
||||
if usageLog.ServiceTier != nil {
|
||||
cmd.ServiceTier = *usageLog.ServiceTier
|
||||
}
|
||||
if usageLog.ReasoningEffort != nil {
|
||||
cmd.ReasoningEffort = *usageLog.ReasoningEffort
|
||||
}
|
||||
if usageLog.SubscriptionID != nil {
|
||||
cmd.SubscriptionID = usageLog.SubscriptionID
|
||||
}
|
||||
}
|
||||
|
||||
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
|
||||
cmd.SubscriptionID = &p.Subscription.ID
|
||||
cmd.SubscriptionCost = p.Cost.TotalCost
|
||||
} else if p.Cost.ActualCost > 0 {
|
||||
cmd.BalanceCost = p.Cost.ActualCost
|
||||
}
|
||||
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
cmd.APIKeyQuotaCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
return cmd
|
||||
}
|
||||
|
||||
func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) {
|
||||
if p == nil || deps == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := buildUsageBillingCommand(requestID, usageLog, p)
|
||||
if cmd == nil || cmd.RequestID == "" || repo == nil {
|
||||
postUsageBilling(ctx, p, deps)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
billingCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
result, err := repo.Apply(billingCtx, cmd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if result == nil || !result.Applied {
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if result.APIKeyQuotaExhausted {
|
||||
if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" {
|
||||
invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key)
|
||||
}
|
||||
}
|
||||
|
||||
finalizePostUsageBilling(p, deps)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
|
||||
if p == nil || p.Cost == nil || deps == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if p.IsSubscriptionBill {
|
||||
if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
|
||||
}
|
||||
} else if p.Cost.ActualCost > 0 && p.User != nil {
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
|
||||
}
|
||||
|
||||
if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() {
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost)
|
||||
}
|
||||
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
}
|
||||
|
||||
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
base := context.Background()
|
||||
if ctx != nil {
|
||||
base = context.WithoutCancel(ctx)
|
||||
}
|
||||
return context.WithTimeout(base, postUsageBillingTimeout)
|
||||
}
|
||||
|
||||
// detachStreamUpstreamContext returns the context to use for an upstream
// request.
//
// For streaming requests (stream == true) the result keeps ctx's values but is
// no longer cancelled when ctx is, so the upstream read can continue after the
// caller's context ends. For non-streaming requests ctx is returned unchanged.
//
// A nil ctx is replaced by context.Background() on both paths, so callers never
// receive a nil context. The returned CancelFunc is always a no-op; it exists
// so call sites keep a uniform acquire/release shape.
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
	noop := func() {}

	if ctx == nil {
		return context.Background(), noop
	}
	if !stream {
		return ctx, noop
	}
	return context.WithoutCancel(ctx), noop
}
|
||||
|
||||
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
|
||||
type billingDeps struct {
|
||||
accountRepo AccountRepository
|
||||
@@ -6657,6 +6904,28 @@ func (s *GatewayService) billingDeps() *billingDeps {
|
||||
}
|
||||
}
|
||||
|
||||
func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) {
|
||||
if repo == nil || usageLog == nil {
|
||||
return
|
||||
}
|
||||
usageCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
if writer, ok := repo.(usageLogBestEffortWriter); ok {
|
||||
if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
||||
if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := repo.Create(usageCtx, usageLog); err != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
@@ -6758,11 +7027,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
mediaType = &result.MediaType
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -6807,33 +7077,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -6844,13 +7113,14 @@ type RecordUsageLongContextInput struct {
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService *APIKeyService // API Key 配额服务(可选)
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
||||
}
|
||||
|
||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||
@@ -6933,11 +7203,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -6981,33 +7252,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user