mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-18 05:44:46 +08:00
feat: decouple billing correctness from usage log batching
This commit is contained in:
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
|
||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||
}
|
||||
|
||||
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
|
||||
|
||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||
if aggErr != nil {
|
||||
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
if usageErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil {
|
||||
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
|
||||
if dedupErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil && dedupErr == nil {
|
||||
s.lastRetentionCleanup.Store(now)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,12 +12,18 @@ import (
|
||||
|
||||
// dashboardAggregationRepoTestStub is a hand-written repository test double.
// Its methods bump call counters and hand back the pre-configured errors so
// tests can assert on how the service drove the repository.
type dashboardAggregationRepoTestStub struct {
	// Call counters. recomputeCalls, cleanupUsageCalls, cleanupDedupCalls and
	// ensurePartitionCalls are incremented by the matching stub methods;
	// aggregateCalls is presumably incremented by AggregateRange (its body is
	// not visible here — confirm).
	aggregateCalls       int
	recomputeCalls       int
	cleanupUsageCalls    int
	cleanupDedupCalls    int
	ensurePartitionCalls int
	// Recorded values; presumably the last aggregation window and the stored
	// watermark — verify against the methods that write them.
	lastStart time.Time
	lastEnd   time.Time
	watermark time.Time
	// Injected errors. cleanupUsageErr, cleanupDedupErr and ensurePartitionErr
	// are returned verbatim by CleanupUsageLogs, CleanupUsageBillingDedup and
	// EnsureUsageLogsPartitions respectively.
	aggregateErr         error
	cleanupAggregatesErr error
	cleanupUsageErr      error
	cleanupDedupErr      error
	ensurePartitionErr   error
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
s.recomputeCalls++
|
||||
return s.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
s.cleanupUsageCalls++
|
||||
return s.cleanupUsageErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
s.cleanupDedupCalls++
|
||||
return s.cleanupDedupErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
s.ensurePartitionCalls++
|
||||
return s.ensurePartitionErr
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupUsageCalls)
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
UsageBillingDedupDays: 2,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.runScheduledAggregation()
|
||||
|
||||
require.Equal(t, 1, repo.ensurePartitionCalls)
|
||||
require.Equal(t, 1, repo.aggregateCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||
|
||||
@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupUsageBillingDedup is a no-op on this stub: both arguments are ignored
// and nil is always returned.
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
	return nil
}
|
||||
|
||||
// EnsureUsageLogsPartitions is a no-op on this stub: both arguments are
// ignored and nil is always returned.
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
	return nil
}
|
||||
|
||||
@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
billingCacheService: nil,
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
billingCacheService: nil,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
|
||||
require.Equal(t, 5, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
},
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
|
||||
"",
|
||||
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
|
||||
"",
|
||||
}, "\n"))),
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
|
||||
_ = pr.Close()
|
||||
<-done
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.Equal(t, 9, result.usage.InputTokens)
|
||||
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.Equal(t, 8, result.usage.InputTokens)
|
||||
|
||||
261
backend/internal/service/gateway_record_usage_test.go
Normal file
261
backend/internal/service/gateway_record_usage_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newGatewayRecordUsageServiceForTest builds a GatewayService through
// NewGatewayService with only the collaborators the RecordUsage tests need:
// the given usage-log, user and subscription repositories, a config whose
// Default.RateMultiplier is 1.1, a BillingService built from that config, an
// empty BillingCacheService and an empty DeferredService. Every other
// constructor argument is nil.
//
// The call is positional, so argument order matters. The parameter names in
// the comments below are presumed from the constructor signature — verify
// against NewGatewayService before reordering anything.
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
	cfg := &config.Config{}
	cfg.Default.RateMultiplier = 1.1
	return NewGatewayService(
		nil, // presumably accountRepo
		nil, // presumably groupRepo
		usageRepo,
		nil, // presumably usageBillingRepo; tests that need it set the field afterwards
		userRepo,
		subRepo,
		nil, // presumably userGroupRateRepo
		nil,
		cfg,
		nil,
		nil,
		NewBillingService(cfg, nil),
		nil,
		&BillingCacheService{},
		nil,
		nil,
		&DeferredService{},
		nil,
		nil,
		nil,
		nil,
		nil,
	)
}
|
||||
|
||||
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
svc.usageBillingRepo = billingRepo
|
||||
return svc
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_detached_ctx",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 501,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_payload_hash",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
RequestPayloadHash: payloadHash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_payload_fallback",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_not_persisted",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 503,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 603},
|
||||
Account: &Account{ID: 703},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_long_context_detached_ctx",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 12,
|
||||
OutputTokens: 8,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 502,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 602},
|
||||
Account: &Account{ID: 702},
|
||||
LongContextThreshold: 200000,
|
||||
LongContextMultiplier: 2,
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 504},
|
||||
User: &User{ID: 604},
|
||||
Account: &Account{ID: 704},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_billing_fail",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 505},
|
||||
User: &User{ID: 605},
|
||||
Account: &Account{ID: 705},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
@@ -50,6 +50,7 @@ const (
|
||||
|
||||
defaultUserGroupRateCacheTTL = 30 * time.Second
|
||||
defaultModelsListCacheTTL = 15 * time.Second
|
||||
postUsageBillingTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -106,6 +107,52 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
|
||||
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
|
||||
}
|
||||
|
||||
func claudeUsageHasAnyTokens(usage *ClaudeUsage) bool {
|
||||
return usage != nil && (usage.InputTokens > 0 ||
|
||||
usage.OutputTokens > 0 ||
|
||||
usage.CacheCreationInputTokens > 0 ||
|
||||
usage.CacheReadInputTokens > 0 ||
|
||||
usage.CacheCreation5mTokens > 0 ||
|
||||
usage.CacheCreation1hTokens > 0)
|
||||
}
|
||||
|
||||
func openAIUsageHasAnyTokens(usage *OpenAIUsage) bool {
|
||||
return usage != nil && (usage.InputTokens > 0 ||
|
||||
usage.OutputTokens > 0 ||
|
||||
usage.CacheCreationInputTokens > 0 ||
|
||||
usage.CacheReadInputTokens > 0)
|
||||
}
|
||||
|
||||
func openAIStreamEventIsTerminal(data string) bool {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if trimmed == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
switch gjson.Get(trimmed, "type").String() {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func anthropicStreamEventIsTerminal(eventName, data string) bool {
|
||||
if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") {
|
||||
return true
|
||||
}
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if trimmed == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
return gjson.Get(trimmed, "type").String() == "message_stop"
|
||||
}
|
||||
|
||||
func cloneStringSlice(src []string) []string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -504,6 +551,7 @@ type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageBillingRepo UsageBillingRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
@@ -537,6 +585,7 @@ func NewGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageBillingRepo UsageBillingRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
@@ -563,6 +612,7 @@ func NewGatewayService(
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
@@ -4049,7 +4099,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4127,7 +4179,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// also downgrade tool_use/tool_result blocks to text.
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseRetryCtx()
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
@@ -4159,7 +4213,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseRetryCtx2()
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr2 == nil {
|
||||
@@ -4226,7 +4282,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
rectifiedBody, applied := RectifyThinkingBudget(body)
|
||||
if applied && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseBudgetRetryCtx()
|
||||
if buildErr == nil {
|
||||
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
@@ -4498,7 +4556,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
var resp *http.Response
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4774,6 +4834,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawTerminalEvent := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -4836,17 +4897,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
|
||||
flusher.Flush()
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
|
||||
}
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err())
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
@@ -4858,11 +4922,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
line := ev.line
|
||||
if data, ok := extractAnthropicSSEDataLine(line); ok {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if anthropicStreamEventIsTerminal("", trimmed) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsagePassthrough(data, usage)
|
||||
} else {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
@@ -4884,8 +4956,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
|
||||
if s.rateLimitService != nil {
|
||||
@@ -6011,6 +6082,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
sawTerminalEvent := false
|
||||
|
||||
pendingEventLines := make([]string, 0, 4)
|
||||
|
||||
@@ -6041,6 +6113,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
if dataLine == "[DONE]" {
|
||||
sawTerminalEvent = true
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
block = "event: " + eventName + "\n"
|
||||
@@ -6107,6 +6180,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
usagePatch := s.extractSSEUsagePatch(event)
|
||||
if anthropicStreamEventIsTerminal(eventName, dataLine) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if !eventChanged {
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
@@ -6140,18 +6216,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// 上游完成,返回结果
|
||||
if !sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
|
||||
}
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
@@ -6209,9 +6289,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
// 客户端已断开,上游也超时了,返回已收集的 usage
|
||||
logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
@@ -6557,15 +6635,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
}
|
||||
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
||||
@@ -6574,6 +6653,14 @@ type APIKeyQuotaUpdater interface {
|
||||
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
}
|
||||
|
||||
// apiKeyAuthCacheInvalidator is an optional capability that an
// APIKeyQuotaUpdater implementation may also provide. applyUsageBilling
// probes for it with a type assertion and calls it once the billing
// repository reports that the API key quota is exhausted.
type apiKeyAuthCacheInvalidator interface {
	// InvalidateAuthCacheByKey presumably evicts the cached auth entry for
	// the given raw API key; confirm against the implementing service.
	InvalidateAuthCacheByKey(ctx context.Context, key string)
}
|
||||
|
||||
type usageLogBestEffortWriter interface {
|
||||
CreateBestEffort(ctx context.Context, log *UsageLog) error
|
||||
}
|
||||
|
||||
// postUsageBillingParams 统一扣费所需的参数
|
||||
type postUsageBillingParams struct {
|
||||
Cost *CostBreakdown
|
||||
@@ -6581,6 +6668,7 @@ type postUsageBillingParams struct {
|
||||
APIKey *APIKey
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
RequestPayloadHash string
|
||||
IsSubscriptionBill bool
|
||||
AccountRateMultiplier float64
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
@@ -6592,19 +6680,22 @@ type postUsageBillingParams struct {
|
||||
// - API Key 限速用量更新
|
||||
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
||||
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||
billingCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
cost := p.Cost
|
||||
|
||||
// 1. 订阅 / 余额扣费
|
||||
if p.IsSubscriptionBill {
|
||||
if cost.TotalCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
|
||||
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
||||
@@ -6613,31 +6704,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
|
||||
// 2. API Key 配额
|
||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 更新账号最近使用时间
|
||||
finalizePostUsageBilling(p, deps)
|
||||
}
|
||||
|
||||
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
||||
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
if ctx != nil {
|
||||
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||
return "client:" + strings.TrimSpace(clientRequestID)
|
||||
}
|
||||
if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
|
||||
return "local:" + strings.TrimSpace(requestID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
|
||||
if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" {
|
||||
return payloadHash
|
||||
}
|
||||
if ctx != nil {
|
||||
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||
return "client:" + strings.TrimSpace(clientRequestID)
|
||||
}
|
||||
if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
|
||||
return "local:" + strings.TrimSpace(requestID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand {
|
||||
if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := &UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: p.APIKey.ID,
|
||||
UserID: p.User.ID,
|
||||
AccountID: p.Account.ID,
|
||||
AccountType: p.Account.Type,
|
||||
RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash),
|
||||
}
|
||||
if usageLog != nil {
|
||||
cmd.Model = usageLog.Model
|
||||
cmd.BillingType = usageLog.BillingType
|
||||
cmd.InputTokens = usageLog.InputTokens
|
||||
cmd.OutputTokens = usageLog.OutputTokens
|
||||
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
||||
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
||||
cmd.ImageCount = usageLog.ImageCount
|
||||
if usageLog.MediaType != nil {
|
||||
cmd.MediaType = *usageLog.MediaType
|
||||
}
|
||||
if usageLog.ServiceTier != nil {
|
||||
cmd.ServiceTier = *usageLog.ServiceTier
|
||||
}
|
||||
if usageLog.ReasoningEffort != nil {
|
||||
cmd.ReasoningEffort = *usageLog.ReasoningEffort
|
||||
}
|
||||
if usageLog.SubscriptionID != nil {
|
||||
cmd.SubscriptionID = usageLog.SubscriptionID
|
||||
}
|
||||
}
|
||||
|
||||
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
|
||||
cmd.SubscriptionID = &p.Subscription.ID
|
||||
cmd.SubscriptionCost = p.Cost.TotalCost
|
||||
} else if p.Cost.ActualCost > 0 {
|
||||
cmd.BalanceCost = p.Cost.ActualCost
|
||||
}
|
||||
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
cmd.APIKeyQuotaCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
return cmd
|
||||
}
|
||||
|
||||
func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) {
|
||||
if p == nil || deps == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := buildUsageBillingCommand(requestID, usageLog, p)
|
||||
if cmd == nil || cmd.RequestID == "" || repo == nil {
|
||||
postUsageBilling(ctx, p, deps)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
billingCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
result, err := repo.Apply(billingCtx, cmd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if result == nil || !result.Applied {
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if result.APIKeyQuotaExhausted {
|
||||
if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" {
|
||||
invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key)
|
||||
}
|
||||
}
|
||||
|
||||
finalizePostUsageBilling(p, deps)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
|
||||
if p == nil || p.Cost == nil || deps == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if p.IsSubscriptionBill {
|
||||
if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
|
||||
}
|
||||
} else if p.Cost.ActualCost > 0 && p.User != nil {
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
|
||||
}
|
||||
|
||||
if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() {
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost)
|
||||
}
|
||||
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
}
|
||||
|
||||
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
base := context.Background()
|
||||
if ctx != nil {
|
||||
base = context.WithoutCancel(ctx)
|
||||
}
|
||||
return context.WithTimeout(base, postUsageBillingTimeout)
|
||||
}
|
||||
|
||||
// detachStreamUpstreamContext selects the context for an upstream request.
//
// Non-streaming requests keep ctx untouched. Streaming requests get a context
// that still carries the values of ctx but is no longer cancelled when ctx is
// (context.WithoutCancel); a nil ctx becomes context.Background(). The
// returned CancelFunc is always a no-op and exists only so call sites can
// defer it unconditionally.
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
	noop := func() {}

	switch {
	case !stream:
		return ctx, noop
	case ctx == nil:
		return context.Background(), noop
	default:
		return context.WithoutCancel(ctx), noop
	}
}
|
||||
|
||||
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
|
||||
type billingDeps struct {
|
||||
accountRepo AccountRepository
|
||||
@@ -6657,6 +6904,28 @@ func (s *GatewayService) billingDeps() *billingDeps {
|
||||
}
|
||||
}
|
||||
|
||||
func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) {
|
||||
if repo == nil || usageLog == nil {
|
||||
return
|
||||
}
|
||||
usageCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
if writer, ok := repo.(usageLogBestEffortWriter); ok {
|
||||
if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
||||
if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := repo.Create(usageCtx, usageLog); err != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
@@ -6758,11 +7027,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
mediaType = &result.MediaType
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -6807,33 +7077,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -6844,13 +7113,14 @@ type RecordUsageLongContextInput struct {
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService *APIKeyService // API Key 配额服务(可选)
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
||||
}
|
||||
|
||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||
@@ -6933,11 +7203,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -6981,33 +7252,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,35 +7,63 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIRecordUsageLogRepoStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
inserted bool
|
||||
err error
|
||||
calls int
|
||||
lastLog *UsageLog
|
||||
inserted bool
|
||||
err error
|
||||
calls int
|
||||
lastLog *UsageLog
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||
s.calls++
|
||||
s.lastLog = log
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.inserted, s.err
|
||||
}
|
||||
|
||||
type openAIRecordUsageBillingRepoStub struct {
|
||||
UsageBillingRepository
|
||||
|
||||
result *UsageBillingApplyResult
|
||||
err error
|
||||
calls int
|
||||
lastCmd *UsageBillingCommand
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) {
|
||||
s.calls++
|
||||
s.lastCmd = cmd
|
||||
s.lastCtxErr = ctx.Err()
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.result != nil {
|
||||
return s.result, nil
|
||||
}
|
||||
return &UsageBillingApplyResult{Applied: true}, nil
|
||||
}
|
||||
|
||||
type openAIRecordUsageUserRepoStub struct {
|
||||
UserRepository
|
||||
|
||||
deductCalls int
|
||||
deductErr error
|
||||
lastAmount float64
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
s.deductCalls++
|
||||
s.lastAmount = amount
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.deductErr
|
||||
}
|
||||
|
||||
@@ -44,29 +72,35 @@ type openAIRecordUsageSubRepoStub struct {
|
||||
|
||||
incrementCalls int
|
||||
incrementErr error
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
s.incrementCalls++
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.incrementErr
|
||||
}
|
||||
|
||||
type openAIRecordUsageAPIKeyQuotaStub struct {
|
||||
quotaCalls int
|
||||
rateLimitCalls int
|
||||
err error
|
||||
lastAmount float64
|
||||
quotaCalls int
|
||||
rateLimitCalls int
|
||||
err error
|
||||
lastAmount float64
|
||||
lastQuotaCtxErr error
|
||||
lastRateLimitCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
s.quotaCalls++
|
||||
s.lastAmount = cost
|
||||
s.lastQuotaCtxErr = ctx.Err()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
s.rateLimitCalls++
|
||||
s.lastAmount = cost
|
||||
s.lastRateLimitCtxErr = ctx.Err()
|
||||
return s.err
|
||||
}
|
||||
|
||||
@@ -93,23 +127,38 @@ func i64p(v int64) *int64 {
|
||||
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
svc := NewOpenAIGatewayService(
|
||||
nil,
|
||||
usageRepo,
|
||||
nil,
|
||||
userRepo,
|
||||
subRepo,
|
||||
rateRepo,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
NewBillingService(cfg, nil),
|
||||
nil,
|
||||
&BillingCacheService{},
|
||||
nil,
|
||||
&DeferredService{},
|
||||
nil,
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway.test",
|
||||
)
|
||||
return svc
|
||||
}
|
||||
|
||||
return &OpenAIGatewayService{
|
||||
usageLogRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: subRepo,
|
||||
cfg: cfg,
|
||||
billingService: NewBillingService(cfg, nil),
|
||||
billingCacheService: &BillingCacheService{},
|
||||
deferredService: &DeferredService{},
|
||||
userGroupRateResolver: newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway.test",
|
||||
),
|
||||
}
|
||||
func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
svc.usageBillingRepo = billingRepo
|
||||
return svc
|
||||
}
|
||||
|
||||
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
|
||||
@@ -252,9 +301,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
@@ -272,11 +322,254 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_duplicate_billing_key",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10045,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20045},
|
||||
Account: &Account{ID: 30045},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_usage_log_error",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10041},
|
||||
User: &User{ID: 20041},
|
||||
Account: &Account{ID: 30041},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_not_persisted",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10043,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20043},
|
||||
Account: &Account{ID: 30043},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_detached_billing_ctx",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10042,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20042},
|
||||
Account: &Account{ID: 30042},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_detached_billing_repo_ctx",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10046},
|
||||
User: &User{ID: 20046},
|
||||
Account: &Account{ID: 30046},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.NoError(t, billingRepo.lastCtxErr)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.NoError(t, usageRepo.lastCtxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
|
||||
payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`))
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "openai_payload_hash",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "gpt-5",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
RequestPayloadHash: payloadHash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback")
|
||||
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10047},
|
||||
User: &User{ID: 20047},
|
||||
Account: &Account{ID: 30047},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_billing_fail",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10048},
|
||||
User: &User{ID: 20048},
|
||||
Account: &Account{ID: 30048},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
|
||||
@@ -259,6 +259,7 @@ type openAIWSRetryMetrics struct {
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageBillingRepo UsageBillingRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
@@ -295,6 +296,7 @@ type OpenAIGatewayService struct {
|
||||
func NewOpenAIGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageBillingRepo UsageBillingRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
@@ -312,6 +314,7 @@ func NewOpenAIGatewayService(
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
@@ -2014,7 +2017,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2206,7 +2211,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2543,6 +2550,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
sawTerminalEvent := false
|
||||
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
@@ -2562,6 +2570,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if trimmedData == "[DONE]" {
|
||||
sawDone = true
|
||||
}
|
||||
if openAIStreamEventIsTerminal(trimmedData) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
@@ -2579,19 +2590,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
||||
if sawTerminalEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
err,
|
||||
ctx.Err(),
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
@@ -2605,12 +2611,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && ctx.Err() == nil {
|
||||
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
||||
logger.FromContext(ctx).With(
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("upstream_request_id", upstreamRequestID),
|
||||
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
@@ -3030,6 +3037,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
||||
errorEventSent := false
|
||||
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||
sawTerminalEvent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent || clientDisconnected {
|
||||
return
|
||||
@@ -3060,22 +3068,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||
}
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
|
||||
if scanErr == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sawTerminalEvent {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
}
|
||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
|
||||
return resultWithUsage(), nil, true
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
||||
}
|
||||
if errors.Is(scanErr, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
|
||||
@@ -3098,6 +3111,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
|
||||
dataBytes := []byte(data)
|
||||
if openAIStreamEventIsTerminal(data) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
||||
@@ -3214,8 +3230,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return resultWithUsage(), nil
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
@@ -3313,11 +3328,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
||||
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
return
|
||||
}
|
||||
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
|
||||
if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
|
||||
// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
|
||||
if len(data) < 72 {
|
||||
return
|
||||
}
|
||||
if gjson.GetBytes(data, "type").String() != "response.completed" {
|
||||
eventType := gjson.GetBytes(data, "type").String()
|
||||
if eventType != "response.completed" && eventType != "response.done" {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3670,14 +3686,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -3743,11 +3760,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: billingModel,
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
@@ -3788,29 +3806,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||
func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") {
|
||||
t.Fatalf("expected incomplete stream error, got %v", err)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
||||
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||
@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||
_ = pr.Close()
|
||||
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 2, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.Equal(t, 1, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
||||
require.Equal(t, 3, usage.InputTokens)
|
||||
require.Equal(t, 5, usage.OutputTokens)
|
||||
require.Equal(t, 2, usage.CacheReadInputTokens)
|
||||
|
||||
// done 事件同样可能携带最终 usage
|
||||
svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage)
|
||||
require.Equal(t, 13, usage.InputTokens)
|
||||
require.Equal(t, 15, usage.OutputTokens)
|
||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
|
||||
|
||||
@@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
|
||||
110
backend/internal/service/usage_billing.go
Normal file
110
backend/internal/service/usage_billing.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Sentinel errors of the usage billing layer; compare with errors.Is.
var (
	// ErrUsageBillingRequestIDRequired is the sentinel for a billing request
	// that has no request_id.
	ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required")
	// ErrUsageBillingRequestConflict is the sentinel for a request
	// fingerprint conflict.
	ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict")
)
|
||||
|
||||
// UsageBillingCommand describes one billable request that must be applied at most once.
|
||||
type UsageBillingCommand struct {
|
||||
RequestID string
|
||||
APIKeyID int64
|
||||
RequestFingerprint string
|
||||
RequestPayloadHash string
|
||||
|
||||
UserID int64
|
||||
AccountID int64
|
||||
SubscriptionID *int64
|
||||
AccountType string
|
||||
Model string
|
||||
ServiceTier string
|
||||
ReasoningEffort string
|
||||
BillingType int8
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheCreationTokens int
|
||||
CacheReadTokens int
|
||||
ImageCount int
|
||||
MediaType string
|
||||
|
||||
BalanceCost float64
|
||||
SubscriptionCost float64
|
||||
APIKeyQuotaCost float64
|
||||
APIKeyRateLimitCost float64
|
||||
AccountQuotaCost float64
|
||||
}
|
||||
|
||||
func (c *UsageBillingCommand) Normalize() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.RequestID = strings.TrimSpace(c.RequestID)
|
||||
if strings.TrimSpace(c.RequestFingerprint) == "" {
|
||||
c.RequestFingerprint = buildUsageBillingFingerprint(c)
|
||||
}
|
||||
}
|
||||
|
||||
func buildUsageBillingFingerprint(c *UsageBillingCommand) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
raw := fmt.Sprintf(
|
||||
"%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f",
|
||||
c.UserID,
|
||||
c.AccountID,
|
||||
c.APIKeyID,
|
||||
strings.TrimSpace(c.AccountType),
|
||||
strings.TrimSpace(c.Model),
|
||||
strings.TrimSpace(c.ServiceTier),
|
||||
strings.TrimSpace(c.ReasoningEffort),
|
||||
c.BillingType,
|
||||
c.InputTokens,
|
||||
c.OutputTokens,
|
||||
c.CacheCreationTokens,
|
||||
c.CacheReadTokens,
|
||||
c.ImageCount,
|
||||
strings.TrimSpace(c.MediaType),
|
||||
valueOrZero(c.SubscriptionID),
|
||||
c.BalanceCost,
|
||||
c.SubscriptionCost,
|
||||
c.APIKeyQuotaCost,
|
||||
c.APIKeyRateLimitCost,
|
||||
c.AccountQuotaCost,
|
||||
)
|
||||
if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" {
|
||||
raw += "|" + payloadHash
|
||||
}
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// HashUsageRequestPayload returns the lowercase hex SHA-256 digest of payload,
// or "" when payload is nil or empty.
func HashUsageRequestPayload(payload []byte) string {
	if len(payload) == 0 {
		return ""
	}
	digest := sha256.Sum256(payload)
	return hex.EncodeToString(digest[:])
}
|
||||
|
||||
// valueOrZero dereferences v, returning 0 when v is nil.
func valueOrZero(v *int64) int64 {
	if v == nil {
		return 0
	}
	return *v
}
|
||||
|
||||
// UsageBillingApplyResult is the outcome returned by UsageBillingRepository.Apply.
type UsageBillingApplyResult struct {
	// Applied — NOTE(review): presumably true when this call performed the
	// billing rather than detecting an earlier application; confirm against
	// the repository implementation.
	Applied bool
	// APIKeyQuotaExhausted — NOTE(review): presumably set when applying the
	// command used up the API key's quota; confirm against the implementation.
	APIKeyQuotaExhausted bool
}
|
||||
|
||||
type UsageBillingRepository interface {
|
||||
Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error)
|
||||
}
|
||||
@@ -56,7 +56,8 @@ type cleanupRepoStub struct {
|
||||
}
|
||||
|
||||
// dashboardRepoStub is a test double for the dashboard aggregation repository.
type dashboardRepoStub struct {
	recomputeErr   error // returned by RecomputeRange
	recomputeCalls int   // incremented on every RecomputeRange call
}
|
||||
|
||||
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
@@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
s.recomputeCalls++
|
||||
return s.recomputeErr
|
||||
}
|
||||
|
||||
@@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
|
||||
dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")}
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
|
||||
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
|
||||
@@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
|
||||
dashboardRepo := &dashboardRepoStub{}
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
|
||||
|
||||
60
backend/internal/service/usage_log_create_result.go
Normal file
60
backend/internal/service/usage_log_create_result.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package service
|
||||
|
||||
import "errors"
|
||||
|
||||
// usageLogCreateDisposition classifies a failed usage-log create.
type usageLogCreateDisposition int

const (
	// usageLogCreateDispositionUnknown is the zero value.
	usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
	// usageLogCreateDispositionNotPersisted is the disposition set by
	// MarkUsageLogCreateNotPersisted.
	usageLogCreateDispositionNotPersisted
)

// UsageLogCreateError wraps an error from a usage-log create together with a
// disposition. It supports errors.Is / errors.As through Unwrap.
type UsageLogCreateError struct {
	err         error
	disposition usageLogCreateDisposition
}

// Error returns the wrapped error's message, or a generic message when the
// receiver or the wrapped error is nil.
func (e *UsageLogCreateError) Error() string {
	if e == nil || e.err == nil {
		return "usage log create error"
	}
	return e.err.Error()
}

// Unwrap returns the wrapped error, or nil for a nil receiver.
func (e *UsageLogCreateError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.err
}

// MarkUsageLogCreateNotPersisted wraps err in a UsageLogCreateError tagged as
// not persisted. A nil err yields a nil error, never a typed nil pointer.
func MarkUsageLogCreateNotPersisted(err error) error {
	if err == nil {
		return nil
	}
	return &UsageLogCreateError{
		err:         err,
		disposition: usageLogCreateDispositionNotPersisted,
	}
}

// IsUsageLogCreateNotPersisted reports whether err, or any error it wraps, is
// a UsageLogCreateError tagged as not persisted.
func IsUsageLogCreateNotPersisted(err error) bool {
	if err == nil {
		return false
	}
	var target *UsageLogCreateError
	if !errors.As(err, &target) {
		return false
	}
	// errors.As also matches a typed-nil *UsageLogCreateError stored in an
	// error interface; guard so that case reports false instead of panicking.
	return target != nil && target.disposition == usageLogCreateDispositionNotPersisted
}
|
||||
|
||||
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
|
||||
if inserted {
|
||||
return true
|
||||
}
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return !IsUsageLogCreateNotPersisted(err)
|
||||
}
|
||||
Reference in New Issue
Block a user