From c9debc50b1b95ed580d74e22e47219ca0c624195 Mon Sep 17 00:00:00 2001 From: ius Date: Wed, 11 Mar 2026 20:29:48 +0800 Subject: [PATCH] Batch usage log writes in repository --- backend/internal/repository/usage_log_repo.go | 502 +++++++++++++++--- .../usage_log_repo_integration_test.go | 99 ++++ 2 files changed, 538 insertions(+), 63 deletions(-) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index c91a68e5..8ffcb2f3 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "os" + "strconv" "strings" + "sync" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -43,6 +45,39 @@ func safeDateFormat(granularity string) string { type usageLogRepository struct { client *dbent.Client sql sqlExecutor + db *sql.DB + + createBatchOnce sync.Once + createBatchCh chan usageLogCreateRequest +} + +const ( + usageLogCreateBatchMaxSize = 64 + usageLogCreateBatchWindow = 3 * time.Millisecond + usageLogCreateBatchQueueCap = 4096 +) + +type usageLogCreateRequest struct { + log *service.UsageLog + resultCh chan usageLogCreateResult +} + +type usageLogCreateResult struct { + inserted bool + err error +} + +type usageLogInsertPrepared struct { + createdAt time.Time + requestID string + rateMultiplier float64 + requestType int16 + args []any +} + +type usageLogBatchState struct { + ID int64 + CreatedAt time.Time } func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository { @@ -51,7 +86,11 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository { // 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。 - return &usageLogRepository{client: client, sql: sqlq} + repo := &usageLogRepository{client: client, sql: sqlq} + if db, ok := sqlq.(*sql.DB); ok { + repo.db = db + } + return repo } // getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤) @@ -82,24 +121,25 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) return false, nil } - // 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。 - // 无事务时回退到默认的 *sql.DB 执行器。 - sqlq := r.sql if tx := dbent.TxFromContext(ctx); tx != nil { - sqlq = tx.Client() + return r.createSingle(ctx, tx.Client(), log) } - - createdAt := log.CreatedAt - if createdAt.IsZero() { - createdAt = time.Now() + if r.db == nil { + return r.createSingle(ctx, r.sql, log) } - requestID := strings.TrimSpace(log.RequestID) + if requestID == "" { + return r.createSingle(ctx, r.sql, log) + } log.RequestID = requestID + return r.createBatched(ctx, log) +} - rateMultiplier := log.RateMultiplier - log.SyncRequestTypeAndLegacyFields() - requestType := int16(log.RequestType) +func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) { + prepared := prepareUsageLogInsert(log) + if sqlq == nil { + sqlq = r.sql + } query := ` INSERT INTO usage_logs ( @@ -151,6 +191,336 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) RETURNING id, created_at ` + if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil { + if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" { + selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" + if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { + return false, err + } + log.RateMultiplier = prepared.rateMultiplier + return false, nil + } else { + return false, err + } + } + log.RateMultiplier = prepared.rateMultiplier + return true, nil +} + +func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) { + if r.db == nil { + return r.createSingle(ctx, r.sql, log) + } + r.ensureCreateBatcher() + if r.createBatchCh == nil { + return r.createSingle(ctx, r.sql, log) + } + + req := usageLogCreateRequest{ + log: log, + resultCh: make(chan usageLogCreateResult, 1), + } + + select { + case r.createBatchCh <- req: + case <-ctx.Done(): + return false, ctx.Err() + default: + return r.createSingle(ctx, r.sql, log) + } + + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-ctx.Done(): + return false, ctx.Err() + } +} + +func (r *usageLogRepository) ensureCreateBatcher() { + if r == nil || r.db == nil { + return + } + r.createBatchOnce.Do(func() { + r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap) + go r.runCreateBatcher(r.db) + }) +} + +func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { + for { + first, ok := <-r.createBatchCh + if !ok { + return + } + + batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogCreateBatchWindow) + batchLoop: + for len(batch) < usageLogCreateBatchMaxSize { + select { + case req, ok := <-r.createBatchCh: + if !ok { + break batchLoop + } + batch = append(batch, req) + case <-timer.C: + break batchLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushCreateBatch(db, batch) + } +} + +func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) { + if len(batch) == 0 { + return + } + + uniqueOrder := make([]string, 0, len(batch)) + preparedByKey := make(map[string]usageLogInsertPrepared, len(batch)) + requestsByKey := make(map[string][]usageLogCreateRequest, len(batch)) + fallback := make([]usageLogCreateRequest, 0) + + for _, req := range batch { + if req.log == nil { + sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil}) + continue + } + prepared := prepareUsageLogInsert(req.log) + if prepared.requestID == "" { + fallback = append(fallback, req) + continue + } + key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID) + if _, exists := requestsByKey[key]; !exists { + uniqueOrder = append(uniqueOrder, key) + preparedByKey[key] = prepared + } + requestsByKey[key] = append(requestsByKey[key], req) + } + + if len(uniqueOrder) > 0 { + insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) + if err != nil { + for _, key := range uniqueOrder { + fallback = append(fallback, requestsByKey[key]...) + } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, ok := stateMap[key] + if !ok { + for _, req := range reqs { + sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{ + inserted: false, + err: fmt.Errorf("usage log batch state missing for key=%s", key), + }) + } + continue + } + for idx, req := range reqs { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{ + inserted: idx == 0 && insertedMap[key], + err: nil, + }) + } + } + } + } + + if len(fallback) == 0 { + return + } + + fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for _, req := range fallback { + inserted, err := r.createSingle(fallbackCtx, db, req.log) + sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err}) + } +} + +func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) { + if len(keys) == 0 { + return map[string]bool{}, map[string]usageLogBatchState{}, nil + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey) + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return nil, nil, err + } + insertedMap := make(map[string]bool, len(keys)) + for rows.Next() { + var ( + requestID string + apiKeyID int64 + id int64 + createdAt time.Time + ) + if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil { + _ = rows.Close() + return nil, nil, err + } + insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, nil, err + } + _ = rows.Close() + + stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey) + if err != nil { + return nil, nil, err + } + return insertedMap, stateMap, nil +} + +func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) VALUES `) + + args := make([]any, 0, len(keys)*36) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + prepared := preparedByKey[key] + for i := 0; i < len(prepared.args); i++ { + if i > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) + } + _, _ = query.WriteString(` + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING request_id, api_key_id, id, created_at + `) + return query.String(), args +} + +func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) { + var query strings.Builder + _, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `) + args := make([]any, 0, len(keys)*2) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(" OR ") + } + prepared := preparedByKey[key] + apiKeyID := prepared.args[1] + _, _ = query.WriteString("(request_id = $") + _, _ = query.WriteString(strconv.Itoa(argPos)) + _, _ = query.WriteString(" AND api_key_id = $") + _, _ = query.WriteString(strconv.Itoa(argPos + 1)) + _, _ = query.WriteString(")") + args = append(args, prepared.requestID, apiKeyID) + argPos += 2 + } + + rows, err := db.QueryContext(ctx, query.String(), args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + stateMap := make(map[string]usageLogBatchState, len(keys)) + for rows.Next() { + var ( + requestID string + apiKeyID int64 + id int64 + createdAt time.Time + ) + if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil { + return nil, err + } + stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{ + ID: id, + CreatedAt: createdAt, + } + } + if err := rows.Err(); err != nil { + return nil, err + } + return stateMap, nil +} + +func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { + createdAt := log.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now() + } + + requestID := strings.TrimSpace(log.RequestID) + log.RequestID = requestID + + rateMultiplier := log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) + groupID := nullInt64(log.GroupID) subscriptionID := nullInt64(log.SubscriptionID) duration := nullInt(log.DurationMs) @@ -167,58 +537,64 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) requestIDArg = requestID } - args := []any{ - log.UserID, - log.APIKeyID, - log.AccountID, - requestIDArg, - log.Model, - groupID, - subscriptionID, - log.InputTokens, - log.OutputTokens, - log.CacheCreationTokens, - log.CacheReadTokens, - log.CacheCreation5mTokens, - log.CacheCreation1hTokens, - log.InputCost, - log.OutputCost, - log.CacheCreationCost, - log.CacheReadCost, - log.TotalCost, - log.ActualCost, - rateMultiplier, - log.AccountRateMultiplier, - log.BillingType, - requestType, - log.Stream, - log.OpenAIWSMode, - duration, - firstToken, - userAgent, - ipAddress, - log.ImageCount, - imageSize, - mediaType, - serviceTier, - reasoningEffort, - log.CacheTTLOverridden, - createdAt, + return usageLogInsertPrepared{ + createdAt: createdAt, + requestID: requestID, + rateMultiplier: rateMultiplier, + requestType: requestType, + args: []any{ + log.UserID, + log.APIKeyID, + log.AccountID, + requestIDArg, + log.Model, + groupID, + subscriptionID, + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + rateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + requestType, + log.Stream, + log.OpenAIWSMode, + duration, + firstToken, + userAgent, + ipAddress, + log.ImageCount, + imageSize, + mediaType, + serviceTier, + reasoningEffort, + log.CacheTTLOverridden, + createdAt, + }, } - if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { - if errors.Is(err, sql.ErrNoRows) && requestID != "" { - selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" - if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { - return false, err - } - log.RateMultiplier = rateMultiplier - return false, nil - } else { - return false, err - } +} + +func usageLogBatchKey(requestID string, apiKeyID int64) string { + return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10) +} + +func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateResult) { + if ch == nil { + return + } + select { + case ch <- res: + default: } - log.RateMultiplier = rateMultiplier - return true, nil } func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 4d50f7de..d2e1e9d4 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -4,6 +4,8 @@ package repository import ( "context" + "fmt" + "sync" "testing" "time" @@ -14,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -84,6 +87,102 @@ func (s *UsageLogRepoSuite) TestCreate() { s.Require().NotZero(log.ID) } +func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()}) + + const total = 16 + results := make([]bool, total) + errs := make([]error, total) + logs := make([]*service.UsageLog, total) + + var wg sync.WaitGroup + wg.Add(total) + for i := 0; i < total; i++ { + i := i + logs[i] = &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10 + i, + OutputTokens: 20 + i, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + go func() { + defer wg.Done() + results[i], errs[i] = repo.Create(ctx, logs[i]) + }() + } + wg.Wait() + + for i := 0; i < total; i++ { + require.NoError(t, errs[i]) + require.True(t, results[i]) + require.NotZero(t, logs[i].ID) + } + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count)) + require.Equal(t, total, count) +} + +func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()}) + requestID := uuid.NewString() + + log1 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + log2 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + + inserted1, err1 := repo.Create(ctx, log1) + inserted2, err2 := repo.Create(ctx, log2) + require.NoError(t, err1) + require.NoError(t, err2) + require.True(t, inserted1) + require.False(t, inserted2) + require.Equal(t, log1.ID, log2.ID) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)) + require.Equal(t, 1, count) +} + func (s *UsageLogRepoSuite) TestGetByID() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})