From c9debc50b1b95ed580d74e22e47219ca0c624195 Mon Sep 17 00:00:00 2001 From: ius Date: Wed, 11 Mar 2026 20:29:48 +0800 Subject: [PATCH 01/18] Batch usage log writes in repository --- backend/internal/repository/usage_log_repo.go | 502 +++++++++++++++--- .../usage_log_repo_integration_test.go | 99 ++++ 2 files changed, 538 insertions(+), 63 deletions(-) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index c91a68e5..8ffcb2f3 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "os" + "strconv" "strings" + "sync" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -43,6 +45,39 @@ func safeDateFormat(granularity string) string { type usageLogRepository struct { client *dbent.Client sql sqlExecutor + db *sql.DB + + createBatchOnce sync.Once + createBatchCh chan usageLogCreateRequest +} + +const ( + usageLogCreateBatchMaxSize = 64 + usageLogCreateBatchWindow = 3 * time.Millisecond + usageLogCreateBatchQueueCap = 4096 +) + +type usageLogCreateRequest struct { + log *service.UsageLog + resultCh chan usageLogCreateResult +} + +type usageLogCreateResult struct { + inserted bool + err error +} + +type usageLogInsertPrepared struct { + createdAt time.Time + requestID string + rateMultiplier float64 + requestType int16 + args []any +} + +type usageLogBatchState struct { + ID int64 + CreatedAt time.Time } func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository { @@ -51,7 +86,11 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository { // 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。 - return &usageLogRepository{client: client, sql: sqlq} + repo := &usageLogRepository{client: client, sql: sqlq} + if db, ok := sqlq.(*sql.DB); ok { + repo.db = db + } + 
return repo } // getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤) @@ -82,24 +121,25 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) return false, nil } - // 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。 - // 无事务时回退到默认的 *sql.DB 执行器。 - sqlq := r.sql if tx := dbent.TxFromContext(ctx); tx != nil { - sqlq = tx.Client() + return r.createSingle(ctx, tx.Client(), log) } - - createdAt := log.CreatedAt - if createdAt.IsZero() { - createdAt = time.Now() + if r.db == nil { + return r.createSingle(ctx, r.sql, log) } - requestID := strings.TrimSpace(log.RequestID) + if requestID == "" { + return r.createSingle(ctx, r.sql, log) + } log.RequestID = requestID + return r.createBatched(ctx, log) +} - rateMultiplier := log.RateMultiplier - log.SyncRequestTypeAndLegacyFields() - requestType := int16(log.RequestType) +func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) { + prepared := prepareUsageLogInsert(log) + if sqlq == nil { + sqlq = r.sql + } query := ` INSERT INTO usage_logs ( @@ -151,6 +191,336 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) RETURNING id, created_at ` + if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil { + if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" { + selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" + if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { + return false, err + } + log.RateMultiplier = prepared.rateMultiplier + return false, nil + } else { + return false, err + } + } + log.RateMultiplier = prepared.rateMultiplier + return true, nil +} + +func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) { + if r.db == nil { + return r.createSingle(ctx, r.sql, log) + } + r.ensureCreateBatcher() 
+ if r.createBatchCh == nil { + return r.createSingle(ctx, r.sql, log) + } + + req := usageLogCreateRequest{ + log: log, + resultCh: make(chan usageLogCreateResult, 1), + } + + select { + case r.createBatchCh <- req: + case <-ctx.Done(): + return false, ctx.Err() + default: + return r.createSingle(ctx, r.sql, log) + } + + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-ctx.Done(): + return false, ctx.Err() + } +} + +func (r *usageLogRepository) ensureCreateBatcher() { + if r == nil || r.db == nil { + return + } + r.createBatchOnce.Do(func() { + r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap) + go r.runCreateBatcher(r.db) + }) +} + +func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { + for { + first, ok := <-r.createBatchCh + if !ok { + return + } + + batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogCreateBatchWindow) + batchLoop: + for len(batch) < usageLogCreateBatchMaxSize { + select { + case req, ok := <-r.createBatchCh: + if !ok { + break batchLoop + } + batch = append(batch, req) + case <-timer.C: + break batchLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushCreateBatch(db, batch) + } +} + +func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) { + if len(batch) == 0 { + return + } + + uniqueOrder := make([]string, 0, len(batch)) + preparedByKey := make(map[string]usageLogInsertPrepared, len(batch)) + requestsByKey := make(map[string][]usageLogCreateRequest, len(batch)) + fallback := make([]usageLogCreateRequest, 0) + + for _, req := range batch { + if req.log == nil { + sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil}) + continue + } + prepared := prepareUsageLogInsert(req.log) + if prepared.requestID == "" { + fallback = append(fallback, req) + continue + } + key := 
usageLogBatchKey(prepared.requestID, req.log.APIKeyID) + if _, exists := requestsByKey[key]; !exists { + uniqueOrder = append(uniqueOrder, key) + preparedByKey[key] = prepared + } + requestsByKey[key] = append(requestsByKey[key], req) + } + + if len(uniqueOrder) > 0 { + insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) + if err != nil { + for _, key := range uniqueOrder { + fallback = append(fallback, requestsByKey[key]...) + } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, ok := stateMap[key] + if !ok { + for _, req := range reqs { + sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{ + inserted: false, + err: fmt.Errorf("usage log batch state missing for key=%s", key), + }) + } + continue + } + for idx, req := range reqs { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{ + inserted: idx == 0 && insertedMap[key], + err: nil, + }) + } + } + } + } + + if len(fallback) == 0 { + return + } + + fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for _, req := range fallback { + inserted, err := r.createSingle(fallbackCtx, db, req.log) + sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err}) + } +} + +func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) { + if len(keys) == 0 { + return map[string]bool{}, map[string]usageLogBatchState{}, nil + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey) + rows, err := db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, nil, err + } + insertedMap := make(map[string]bool, len(keys)) + for rows.Next() { + var ( + requestID string + apiKeyID int64 + id int64 + createdAt time.Time + ) + if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil { + _ = rows.Close() + return nil, nil, err + } + insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, nil, err + } + _ = rows.Close() + + stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey) + if err != nil { + return nil, nil, err + } + return insertedMap, stateMap, nil +} + +func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) VALUES `) + + args := make([]any, 0, len(keys)*36) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + prepared := preparedByKey[key] + for i := 0; i < len(prepared.args); i++ { + if i > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) 
+ } + _, _ = query.WriteString(` + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING request_id, api_key_id, id, created_at + `) + return query.String(), args +} + +func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) { + var query strings.Builder + _, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `) + args := make([]any, 0, len(keys)*2) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(" OR ") + } + prepared := preparedByKey[key] + apiKeyID := prepared.args[1] + _, _ = query.WriteString("(request_id = $") + _, _ = query.WriteString(strconv.Itoa(argPos)) + _, _ = query.WriteString(" AND api_key_id = $") + _, _ = query.WriteString(strconv.Itoa(argPos + 1)) + _, _ = query.WriteString(")") + args = append(args, prepared.requestID, apiKeyID) + argPos += 2 + } + + rows, err := db.QueryContext(ctx, query.String(), args...) 
+ if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + stateMap := make(map[string]usageLogBatchState, len(keys)) + for rows.Next() { + var ( + requestID string + apiKeyID int64 + id int64 + createdAt time.Time + ) + if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil { + return nil, err + } + stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{ + ID: id, + CreatedAt: createdAt, + } + } + if err := rows.Err(); err != nil { + return nil, err + } + return stateMap, nil +} + +func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { + createdAt := log.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now() + } + + requestID := strings.TrimSpace(log.RequestID) + log.RequestID = requestID + + rateMultiplier := log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) + groupID := nullInt64(log.GroupID) subscriptionID := nullInt64(log.SubscriptionID) duration := nullInt(log.DurationMs) @@ -167,58 +537,64 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) requestIDArg = requestID } - args := []any{ - log.UserID, - log.APIKeyID, - log.AccountID, - requestIDArg, - log.Model, - groupID, - subscriptionID, - log.InputTokens, - log.OutputTokens, - log.CacheCreationTokens, - log.CacheReadTokens, - log.CacheCreation5mTokens, - log.CacheCreation1hTokens, - log.InputCost, - log.OutputCost, - log.CacheCreationCost, - log.CacheReadCost, - log.TotalCost, - log.ActualCost, - rateMultiplier, - log.AccountRateMultiplier, - log.BillingType, - requestType, - log.Stream, - log.OpenAIWSMode, - duration, - firstToken, - userAgent, - ipAddress, - log.ImageCount, - imageSize, - mediaType, - serviceTier, - reasoningEffort, - log.CacheTTLOverridden, - createdAt, + return usageLogInsertPrepared{ + createdAt: createdAt, + requestID: requestID, + rateMultiplier: rateMultiplier, + requestType: requestType, + args: []any{ + log.UserID, + 
log.APIKeyID, + log.AccountID, + requestIDArg, + log.Model, + groupID, + subscriptionID, + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + rateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + requestType, + log.Stream, + log.OpenAIWSMode, + duration, + firstToken, + userAgent, + ipAddress, + log.ImageCount, + imageSize, + mediaType, + serviceTier, + reasoningEffort, + log.CacheTTLOverridden, + createdAt, + }, } - if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { - if errors.Is(err, sql.ErrNoRows) && requestID != "" { - selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" - if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { - return false, err - } - log.RateMultiplier = rateMultiplier - return false, nil - } else { - return false, err - } +} + +func usageLogBatchKey(requestID string, apiKeyID int64) string { + return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10) +} + +func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateResult) { + if ch == nil { + return + } + select { + case ch <- res: + default: } - log.RateMultiplier = rateMultiplier - return true, nil } func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 4d50f7de..d2e1e9d4 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -4,6 +4,8 @@ package repository import ( "context" + "fmt" + "sync" "testing" "time" @@ -14,6 
+16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -84,6 +87,102 @@ func (s *UsageLogRepoSuite) TestCreate() { s.Require().NotZero(log.ID) } +func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()}) + + const total = 16 + results := make([]bool, total) + errs := make([]error, total) + logs := make([]*service.UsageLog, total) + + var wg sync.WaitGroup + wg.Add(total) + for i := 0; i < total; i++ { + i := i + logs[i] = &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10 + i, + OutputTokens: 20 + i, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + go func() { + defer wg.Done() + results[i], errs[i] = repo.Create(ctx, logs[i]) + }() + } + wg.Wait() + + for i := 0; i < total; i++ { + require.NoError(t, errs[i]) + require.True(t, results[i]) + require.NotZero(t, logs[i].ID) + } + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count)) + require.Equal(t, total, count) +} + +func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) 
+ + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()}) + requestID := uuid.NewString() + + log1 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + log2 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + + inserted1, err1 := repo.Create(ctx, log1) + inserted2, err2 := repo.Create(ctx, log2) + require.NoError(t, err1) + require.NoError(t, err2) + require.True(t, inserted1) + require.False(t, inserted2) + require.Equal(t, log1.ID, log2.ID) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)) + require.Equal(t, 1, count) +} + func (s *UsageLogRepoSuite) TestGetByID() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) From 611fd884bd0a356ef8c0f97b2511ff4ca58951cd Mon Sep 17 00:00:00 2001 From: ius Date: Thu, 12 Mar 2026 16:53:18 +0800 Subject: [PATCH 02/18] feat: decouple billing correctness from usage log batching --- PR_REPORT_20260311_db_write_hotspots.md | 307 ++++++++ backend/cmd/server/wire_gen.go | 5 +- backend/internal/config/config.go | 22 +- backend/internal/config/config_test.go | 20 + 
backend/internal/handler/gateway_handler.go | 40 +- ...eway_handler_warmup_intercept_unit_test.go | 1 + .../internal/handler/gemini_v1beta_handler.go | 2 + .../handler/openai_gateway_handler.go | 53 +- .../handler/sora_client_handler_test.go | 2 +- .../internal/handler/sora_gateway_handler.go | 16 +- .../handler/sora_gateway_handler_test.go | 1 + .../repository/dashboard_aggregation_repo.go | 76 +- .../migrations_schema_integration_test.go | 31 + .../internal/repository/usage_billing_repo.go | 308 ++++++++ .../usage_billing_repo_integration_test.go | 279 +++++++ backend/internal/repository/usage_log_repo.go | 741 +++++++++++++++--- .../usage_log_repo_integration_test.go | 208 +++++ backend/internal/repository/wire.go | 1 + .../service/dashboard_aggregation_service.go | 8 +- .../dashboard_aggregation_service_test.go | 60 +- .../service/dashboard_service_test.go | 4 + ...teway_anthropic_apikey_passthrough_test.go | 78 +- .../service/gateway_record_usage_test.go | 261 ++++++ backend/internal/service/gateway_service.go | 406 ++++++++-- .../service/gateway_streaming_test.go | 3 +- .../openai_gateway_record_usage_test.go | 343 +++++++- .../service/openai_gateway_service.go | 99 ++- .../service/openai_gateway_service_test.go | 115 ++- .../openai_ws_protocol_forward_test.go | 1 + backend/internal/service/usage_billing.go | 110 +++ .../service/usage_cleanup_service_test.go | 18 +- .../service/usage_log_create_result.go | 60 ++ .../071_add_usage_billing_dedup.sql | 13 + ...age_billing_dedup_created_at_brin_notx.sql | 7 + .../073_add_usage_billing_dedup_archive.sql | 10 + deploy/build_image.sh | 0 deploy/install-datamanagementd.sh | 0 37 files changed, 3379 insertions(+), 330 deletions(-) create mode 100644 PR_REPORT_20260311_db_write_hotspots.md create mode 100644 backend/internal/repository/usage_billing_repo.go create mode 100644 backend/internal/repository/usage_billing_repo_integration_test.go create mode 100644 backend/internal/service/gateway_record_usage_test.go 
create mode 100644 backend/internal/service/usage_billing.go create mode 100644 backend/internal/service/usage_log_create_result.go create mode 100644 backend/migrations/071_add_usage_billing_dedup.sql create mode 100644 backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql create mode 100644 backend/migrations/073_add_usage_billing_dedup_archive.sql mode change 100755 => 100644 deploy/build_image.sh mode change 100755 => 100644 deploy/install-datamanagementd.sh diff --git a/PR_REPORT_20260311_db_write_hotspots.md b/PR_REPORT_20260311_db_write_hotspots.md new file mode 100644 index 00000000..54db3c92 --- /dev/null +++ b/PR_REPORT_20260311_db_write_hotspots.md @@ -0,0 +1,307 @@ +# PR Report: DB 写入热点与后台查询拥塞排查 + +## 背景 + +线上在高峰期出现了几类明显症状: + +- 管理后台仪表盘接口经常超时,`/api/v1/admin/dashboard/snapshot-v2` 一度达到 50s 以上 +- 管理后台充值接口 `/api/v1/admin/users/:id/balance` 出现 15s 以上超时 +- 登录态刷新、扣费、错误记录在高峰期出现大量 `context deadline exceeded` +- PostgreSQL 曾出现连接打满,后续回退连接池后,主问题转为 WAL/刷盘拥塞 + +本报告基于 `/home/ius/sub2api` 当前源码,目标是给出一份可直接拆成 PR 的修复方案。 + +## 结论 + +这次故障的主因不是单一“慢 SQL”,而是请求成功路径上的同步写库次数过多,叠加部分后台查询仍直接扫 `usage_logs`,最终把 PostgreSQL 的 WAL 刷盘、热点行更新和 outbox 重建链路一起放大。 + +代码层面的核心问题有 6 个。 + +### 1. 成功请求路径同步写库过多 + +`backend/internal/service/gateway_service.go:6594` 的 `postUsageBilling` 在单次请求成功后,可能同步触发以下写操作: + +- `userRepo.DeductBalance` +- `APIKeyService.UpdateQuotaUsed` +- `APIKeyService.UpdateRateLimitUsage` +- `accountRepo.IncrementQuotaUsed` +- `deferredService.ScheduleLastUsedUpdate`(这一项已经做了延迟批量,是正确方向) + +也就是说,一次成功请求不是 1 次落库,而是 3 到 5 次写入。 + +这和线上看到的现象是吻合的: + +- `UPDATE accounts SET extra = ...` +- `INSERT INTO usage_logs ...` +- `INSERT INTO ops_error_logs ...` +- `scheduler_outbox` backlog + +### 2. API Key 配额更新存在额外读写放大 + +`backend/internal/service/api_key_service.go:815` 的 `UpdateQuotaUsed` 当前流程是: + +1. `IncrementQuotaUsed` +2. `GetByID` +3. 
如超限再 `Update` + +对应仓储实现: + +- `backend/internal/repository/api_key_repo.go:441` 只做自增 +- 然后 service 再回表读取完整 API Key +- 之后可能再整行更新状态 + +这让“每次扣费后更新 API Key 配额”从 1 条 SQL 变成了最多 3 次数据库交互。 + +### 3. `accounts.extra` 被当成高频热写字段使用 + +两个最重的热点都落在 `accounts.extra`: + +- `backend/internal/repository/account_repo.go:1159` `UpdateExtra` +- `backend/internal/repository/account_repo.go:1683` `IncrementQuotaUsed` + +问题有两个: + +1. 两者都会重写整块 JSONB,并更新 `updated_at` +2. `UpdateExtra` 每次写完都会额外插入一条 `scheduler_outbox` + +尤其 `UpdateExtra` 现在被多处高频调用: + +- `backend/internal/service/openai_gateway_service.go:4039` 持久化 Codex rate-limit snapshot +- `backend/internal/service/ratelimit_service.go:903` 持久化 OpenAI Codex snapshot +- `backend/internal/service/ratelimit_service.go:1013` / `1025` 更新 session window utilization + +这类“监控/额度快照”并不会改变账号是否可调度,却仍然走了: + +- JSONB 更新 +- `updated_at` +- `scheduler_outbox` + +这是明显的写放大。 + +### 4. `scheduler_outbox` 设计偏向“每次状态变更都写一条”,高峰期会反压调度器 + +`backend/internal/repository/scheduler_outbox_repo.go:79` 的 `enqueueSchedulerOutbox` 非常轻,但它被大量调用。 + +例如: + +- `UpdateExtra` 每次都 enqueue `AccountChanged` +- `BatchUpdateLastUsed` 也会 enqueue 一条 `AccountLastUsed` +- 各类账号限流、过载、错误状态切换也都会 enqueue + +对应的 outbox worker 在: + +- `backend/internal/service/scheduler_snapshot_service.go:199` +- `backend/internal/service/scheduler_snapshot_service.go:219` + +它会不断拉取 outbox,再触发 `GetByID`、`rebuildBucket`、`loadAccountsFromDB`。 + +所以当高频写入导致 outbox 增长时,系统不仅多了写,还会反向带出更多读和缓存重建。 + +### 5. 
仪表盘只有一部分走了预聚合,`models/groups/users-trend` 仍然直接扫 `usage_logs` + +好消息是,`dashboard stats` 本身已经接了预聚合表: + +- `backend/internal/repository/usage_log_repo.go:306` +- `backend/internal/repository/usage_log_repo.go:420` +- 预聚合表定义在 `backend/migrations/034_usage_dashboard_aggregation_tables.sql:1` + +但后台慢的不是只有 stats。 + +`snapshot-v2` 默认会同时拉: + +- stats +- trend +- model stats + +见: + +- `backend/internal/handler/admin/dashboard_snapshot_v2_handler.go:68` + +其中: + +- `GetUsageTrendWithFilters` 只有“无过滤、day/hour”时才走预聚合,见 `usage_log_repo.go:1657` +- `GetModelStatsWithFilters` 直接扫 `usage_logs`,见 `usage_log_repo.go:1805` +- `GetGroupStatsWithFilters` 直接扫 `usage_logs`,见 `usage_log_repo.go:1872` +- `GetUserUsageTrend` 直接扫 `usage_logs`,见 `usage_log_repo.go:1101` +- `GetAPIKeyUsageTrend` 直接扫 `usage_logs`,见 `usage_log_repo.go:1046` + +所以线上会出现: + +- stats 快 +- 但 `snapshot-v2` 仍然慢 +- `/admin/dashboard/users-trend` 单独也慢 + +这和你线上看到的日志完全一致。 + +### 6. 管理后台充值是“读用户 -> 整体更新用户 -> 插审计记录” + +`backend/internal/service/admin_service.go:694` 的 `UpdateUserBalance` 当前流程: + +1. `GetByID` +2. 在内存里改 balance +3. `userRepo.Update` +4. `redeemCodeRepo.Create` 记录 admin 调账历史 + +而 `userRepo.Update` 是整用户对象更新,并同步 allowed groups 事务处理: + +- `backend/internal/repository/user_repo.go:118` + +这个接口平时不一定重,但在数据库已经抖动时,会比一个原子 `UPDATE users SET balance = balance + $1` 更脆弱。 + +## 额外观察 + +### `ops_error_logs` 虽然已异步化,但单条写入仍然很重 + +错误日志中间件已经做了队列削峰: + +- `backend/internal/handler/ops_error_logger.go:69` +- `backend/internal/handler/ops_error_logger.go:106` + +这点方向是对的。 + +但落库表本身很重: + +- `backend/internal/repository/ops_repo.go:23` +- `backend/migrations/033_ops_monitoring_vnext.sql:69` +- `backend/migrations/033_ops_monitoring_vnext.sql:470` + +`ops_error_logs` 不仅列很多,还带了多组 B-Tree 和 trigram 索引。高错误率时,即使改成异步,也还是会把 WAL 和 I/O 压上去。 + +## 建议的 PR 拆分 + +建议拆成 4 个 PR,不要在一个 PR 里同时改数据库模型、后台查询和管理接口。 + +### PR 1: 收缩成功请求路径的同步写库次数 + +目标:把一次成功请求的同步写次数从 3 到 5 次,尽量压到 1 到 2 次。 + +建议改动: + +1. 
把 `APIKeyService.UpdateQuotaUsed` 改为单 SQL + - 新增 repo 方法,例如 `IncrementQuotaUsedAndMaybeExhaust` + - 在 SQL 里同时完成 `quota_used += ?` 和 `status = quota_exhausted` + - 返回 `key/status/quota/quota_used` 最小字段,直接失效缓存 + - 删掉当前的 `Increment -> GetByID -> Update` + +2. 把账号 quota 计数从 `accounts.extra` 拆出去 + - 最理想:新增结构化列或独立 `account_quota_counters` 表 + - 次优:至少把 `quota_used/quota_daily_used/quota_weekly_used` 从 JSONB 中剥离 + +3. 对“纯监控型 extra 字段”禁止 enqueue outbox + - 例如 codex snapshot、session_window_utilization + - 这些字段不影响调度,不应该触发 `SchedulerOutboxEventAccountChanged` + +4. 复用现有 `DeferredService` 思路 + - `last_used` 已经是批量刷盘,见 `deferred_service.go:41` + - 可继续扩展 `deferred quota snapshot flush` + +预期收益: + +- 直接减少 WAL 写入量 +- 降低 `accounts` 热点行锁竞争 +- 降低 outbox 增长速度 + +### PR 2: 给 dashboard 补齐预聚合/缓存,避免继续扫 `usage_logs` + +目标:后台仪表盘接口不再直接扫描大表。 + +建议改动: + +1. 为 `users-trend` / `api-keys-trend` 增加小时/天级预聚合表 +2. 为 `model stats` / `group stats` 增加日级聚合表 +3. `snapshot-v2` 增加分段缓存 + - `stats` + - `trend` + - `models` + - `groups` + - `users_trend` + 避免一个 section miss 导致整份 snapshot 重新扫库 +4. 可选:把 `include_model_stats` 默认值从 `true` 改成 `false` + - 至少让默认仪表盘先恢复可用,再按需加载重模块 + +预期收益: + +- `snapshot-v2` +- `/admin/dashboard/users-trend` +- `/admin/dashboard/api-keys-trend` + +这几类接口会从“随数据量线性恶化”变成“近似固定成本”。 + +### PR 3: 简化管理后台充值链路 + +目标:管理充值/扣余额不再依赖整用户对象更新。 + +建议改动: + +1. 新增 repo 原子方法 + - `SetBalance(userID, amount)` + - `AddBalance(userID, delta)` + - `SubtractBalance(userID, delta)` + +2. `UpdateUserBalance` 改为: + - 先执行原子 SQL + - 再读一次最小必要字段返回 + - 审计记录改为异步或降级写 + +3. 审计记录建议改名或独立表 + - 现在把后台调账记录塞进 `redeem_codes`,语义上不干净 + +预期收益: + +- `/api/v1/admin/users/:id/balance` 在库抖时更稳 +- 失败面缩小,不再被 allowed groups 同步事务拖累 + +### PR 4: 为重写路径增加“丢弃策略”和“熔断指标” + +目标:高峰期先保护主链路,不让非核心写入拖死数据库。 + +建议改动: + +1. `ops_error_logs` + - 增加采样或分级开关 + - 对重复 429/5xx 做聚合计数而不是逐条插入 + - 对 request body / headers 存储加更严格开关 + +2. `scheduler_outbox` + - 增加 coalesce/merge 机制 + - 同一账号短时间内多次 `AccountChanged` 合并为一条 + +3. 
指标补齐 + - outbox backlog + - ops error queue dropped + - deferred flush lag + - account extra write QPS + +## 推荐实施顺序 + +1. 先做 PR 1 + - 这是这次线上故障的主链路 +2. 再做 PR 2 + - 解决后台仪表盘慢 +3. 再做 PR 3 + - 解决后台充值接口脆弱 +4. 最后做 PR 4 + - 做长期保护 + +## 验证方案 + +每个 PR 合并前都建议做同一组验证: + +1. 压测成功请求链路,记录单请求 SQL 次数 +2. 观测 PostgreSQL: + - `pg_stat_activity` + - `pg_stat_statements` + - `WALWrite` / `WalSync` + - 每分钟 WAL 增量 +3. 观测接口: + - `/api/v1/auth/refresh` + - `/api/v1/admin/dashboard/snapshot-v2` + - `/api/v1/admin/dashboard/users-trend` + - `/api/v1/admin/users/:id/balance` +4. 观测队列: + - `ops_error_logs` queue length / dropped + - `scheduler_outbox` backlog + +## 可直接作为 PR 描述的摘要 + +This PR reduces database write amplification on the request success path and removes several hot-path writes from `accounts.extra` + `scheduler_outbox`. It also prepares dashboard endpoints to rely on pre-aggregated data instead of scanning `usage_logs` under load. The goal is to keep admin dashboard, balance update, auth refresh, and billing-related paths stable under sustained 500+ RPS traffic. 
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 034c70ec..76b2d0db 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) + usageBillingRepository := repository.NewUsageBillingRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemHandler := handler.NewRedeemHandler(redeemService) @@ -162,9 +163,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, 
settingService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index de876098..e90e56af 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -934,9 +934,10 @@ type DashboardAggregationConfig struct { // DashboardAggregationRetentionConfig 预聚合保留窗口 type DashboardAggregationRetentionConfig struct { - UsageLogsDays int `mapstructure:"usage_logs_days"` - HourlyDays int `mapstructure:"hourly_days"` - DailyDays int `mapstructure:"daily_days"` + UsageLogsDays int `mapstructure:"usage_logs_days"` + UsageBillingDedupDays int 
`mapstructure:"usage_billing_dedup_days"` + HourlyDays int `mapstructure:"hourly_days"` + DailyDays int `mapstructure:"daily_days"` } // UsageCleanupConfig 使用记录清理任务配置 @@ -1301,6 +1302,7 @@ func setDefaults() { viper.SetDefault("dashboard_aggregation.backfill_enabled", false) viper.SetDefault("dashboard_aggregation.backfill_max_days", 31) viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) + viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365) viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) viper.SetDefault("dashboard_aggregation.recompute_days", 2) @@ -1758,6 +1760,12 @@ func (c *Config) Validate() error { if c.DashboardAgg.Retention.UsageLogsDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive") } + if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive") + } + if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") + } if c.DashboardAgg.Retention.HourlyDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive") } @@ -1780,6 +1788,14 @@ func (c *Config) Validate() error { if c.DashboardAgg.Retention.UsageLogsDays < 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative") } + if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative") + } + if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 && + c.DashboardAgg.Retention.UsageLogsDays > 0 && + c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { + return 
fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") + } if c.DashboardAgg.Retention.HourlyDays < 0 { return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 79fcc6d0..abb76549 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { if cfg.DashboardAgg.Retention.UsageLogsDays != 90 { t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays) } + if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 { + t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays) + } if cfg.DashboardAgg.Retention.HourlyDays != 180 { t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays) } @@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 }, wantErr: "dashboard_aggregation.retention.usage_logs_days", }, + { + name: "dashboard aggregation dedup retention", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.Retention.UsageBillingDedupDays = 0 + }, + wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days", + }, + { + name: "dashboard aggregation dedup retention smaller than usage logs", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.Retention.UsageLogsDays = 30 + c.DashboardAgg.Retention.UsageBillingDedupDays = 29 + }, + wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days", + }, { name: "dashboard aggregation disabled interval", mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = 
-1 }, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 4441cf07..676ba0e1 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - ForceCacheBilling: fs.ForceCacheBilling, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + ForceCacheBilling: fs.ForceCacheBilling, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -736,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: currentAPIKey, - User: currentAPIKey.User, - Account: account, - Subscription: currentSubscription, - UserAgent: userAgent, - IPAddress: clientIP, - ForceCacheBilling: fs.ForceCacheBilling, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: currentAPIKey, + User: 
currentAPIKey.User, + Account: account, + Subscription: currentSubscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + ForceCacheBilling: fs.ForceCacheBilling, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 0c94d50b..6bcc0003 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // accountRepo (not used: scheduler snapshot hit) &fakeGroupRepo{group: group}, nil, // usageLogRepo + nil, // usageBillingRepo nil, // userRepo nil, // userSubRepo nil, // userGroupRateRepo diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 50af9c8f..9a16ff3a 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + requestPayloadHash := service.HashUsageRequestPayload(body) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, @@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { Subscription: subscription, UserAgent: userAgent, IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 ForceCacheBilling: fs.ForceCacheBilling, diff --git a/backend/internal/handler/openai_gateway_handler.go 
b/backend/internal/handler/openai_gateway_handler.go index 8567b52b..d23c7efe 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -732,17 +734,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.messages"), 
@@ -1231,14 +1235,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) h.submitUsageRecordTask(func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), + APIKeyService: h.apiKeyService, }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 30a761bd..dab17673 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac // newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( - accountRepo, nil, nil, nil, nil, nil, nil, nil, + accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 48c1e451..06abdf60 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { userAgent := 
c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, }); err != nil { logger.L().With( zap.String("component", "handler.sora_gateway.chat_completions"), diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 688c5d12..088946e7 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -431,6 +431,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, nil, nil, + nil, testutil.StubGatewayCache{}, cfg, nil, diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index 59bbd6a3..e82a73a3 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -17,6 +17,9 @@ type dashboardAggregationRepository struct { sql sqlExecutor } +const usageLogsCleanupBatchSize = 10000 +const usageBillingDedupCleanupBatchSize = 10000 + // NewDashboardAggregationRepository 创建仪表盘预聚合仓储。 func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository { if sqlDB == nil { @@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool { } func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error { + if r == nil || r.sql == nil { + return nil + } 
loc := timezone.Location() startLocal := start.In(loc) endLocal := end.In(loc) @@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta dayEnd = dayEnd.Add(24 * time.Hour) } + if db, ok := r.sql.(*sql.DB); ok { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + txRepo := newDashboardAggregationRepositoryWithSQL(tx) + if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() + } + return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd) +} + +func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error { // 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。 if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil { return err @@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c if isPartitioned { return r.dropUsageLogsPartitions(ctx, cutoff) } - _, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC()) - return err + for { + res, err := r.sql.ExecContext(ctx, ` + WITH victims AS ( + SELECT ctid + FROM usage_logs + WHERE created_at < $1 + LIMIT $2 + ) + DELETE FROM usage_logs + WHERE ctid IN (SELECT ctid FROM victims) + `, cutoff.UTC(), usageLogsCleanupBatchSize) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected < usageLogsCleanupBatchSize { + return nil + } + } +} + +func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + for { + res, err := r.sql.ExecContext(ctx, ` + WITH victims AS ( + SELECT ctid, request_id, api_key_id, request_fingerprint, created_at + FROM usage_billing_dedup + WHERE created_at < $1 + LIMIT $2 + ), archived AS ( + INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, 
created_at) + SELECT request_id, api_key_id, request_fingerprint, created_at + FROM victims + ON CONFLICT (request_id, api_key_id) DO NOTHING + ) + DELETE FROM usage_billing_dedup + WHERE ctid IN (SELECT ctid FROM victims) + `, cutoff.UTC(), usageBillingDedupCleanupBatchSize) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected < usageBillingDedupCleanupBatchSize { + return nil + } + } } func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index 72422d18..dd3019bb 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false) requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) + // usage_billing_dedup: billing idempotency narrow table + var usageBillingDedupRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass)) + require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist") + requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false) + requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key") + requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin") + + var usageBillingDedupArchiveRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass)) + 
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist") + requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false) + requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey") + // settings table should exist var settingsRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass)) @@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) } +func requireIndex(t *testing.T, tx *sql.Tx, table, index string) { + t.Helper() + + var exists bool + err := tx.QueryRowContext(context.Background(), ` +SELECT EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = $1 + AND indexname = $2 +) +`, table, index).Scan(&exists) + require.NoError(t, err, "query pg_indexes for %s.%s", table, index) + require.True(t, exists, "expected index %s on %s", index, table) +} + func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { t.Helper() diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go new file mode 100644 index 00000000..b13cfeb8 --- /dev/null +++ b/backend/internal/repository/usage_billing_repo.go @@ -0,0 +1,308 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type usageBillingRepository struct { + db *sql.DB +} + +func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository { + return &usageBillingRepository{db: sqlDB} +} + +func (r *usageBillingRepository) Apply(ctx 
context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) { + if cmd == nil { + return &service.UsageBillingApplyResult{}, nil + } + if r == nil || r.db == nil { + return nil, errors.New("usage billing repository db is nil") + } + + cmd.Normalize() + if cmd.RequestID == "" { + return nil, service.ErrUsageBillingRequestIDRequired + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + if tx != nil { + _ = tx.Rollback() + } + }() + + applied, err := r.claimUsageBillingKey(ctx, tx, cmd) + if err != nil { + return nil, err + } + if !applied { + return &service.UsageBillingApplyResult{Applied: false}, nil + } + + result := &service.UsageBillingApplyResult{Applied: true} + if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + tx = nil + return result, nil +} + +func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) { + var id int64 + err := tx.QueryRowContext(ctx, ` + INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint) + VALUES ($1, $2, $3) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id + `, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + var existingFingerprint string + if err := tx.QueryRowContext(ctx, ` + SELECT request_fingerprint + FROM usage_billing_dedup + WHERE request_id = $1 AND api_key_id = $2 + `, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil { + return false, err + } + if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) { + return false, service.ErrUsageBillingRequestConflict + } + return false, nil + } + if err != nil { + return false, err + } + var archivedFingerprint string + err = tx.QueryRowContext(ctx, ` + SELECT request_fingerprint + 
FROM usage_billing_dedup_archive + WHERE request_id = $1 AND api_key_id = $2 + `, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint) + if err == nil { + if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) { + return false, service.ErrUsageBillingRequestConflict + } + return false, nil + } + if !errors.Is(err, sql.ErrNoRows) { + return false, err + } + return true, nil +} + +func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error { + if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil { + if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil { + return err + } + } + + if cmd.BalanceCost > 0 { + if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil { + return err + } + } + + if cmd.APIKeyQuotaCost > 0 { + exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost) + if err != nil { + return err + } + result.APIKeyQuotaExhausted = exhausted + } + + if cmd.APIKeyRateLimitCost > 0 { + if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil { + return err + } + } + + if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) { + if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil { + return err + } + } + + return nil +} + +func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error { + const updateSQL = ` + UPDATE user_subscriptions us + SET + daily_usage_usd = us.daily_usage_usd + $1, + weekly_usage_usd = us.weekly_usage_usd + $1, + monthly_usage_usd = us.monthly_usage_usd + $1, + updated_at = NOW() + FROM groups g + WHERE us.id = $2 + AND us.deleted_at IS NULL + AND us.group_id = g.id + AND g.deleted_at 
IS NULL + ` + res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + return nil + } + return service.ErrSubscriptionNotFound +} + +func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error { + res, err := tx.ExecContext(ctx, ` + UPDATE users + SET balance = balance - $1, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + `, amount, userID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + return nil + } + return service.ErrUserNotFound +} + +func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) { + var exhausted bool + err := tx.QueryRowContext(ctx, ` + UPDATE api_keys + SET quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 + AND status = $3 + AND quota_used < quota + AND quota_used + $1 >= quota + THEN $4 + ELSE status + END, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota + `, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted) + if errors.Is(err, sql.ErrNoRows) { + return false, service.ErrAPIKeyNotFound + } + if err != nil { + return false, err + } + return exhausted, nil +} + +func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error { + res, err := tx.ExecContext(ctx, ` + UPDATE api_keys SET + usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END, + usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END, + usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + 
INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END, + window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, + window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END, + window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + `, cost, apiKeyID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrAPIKeyNotFound + } + return nil +} + +func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error { + rows, err := tx.QueryContext(ctx, + `UPDATE accounts SET extra = ( + COALESCE(extra, '{}'::jsonb) + || jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1) + || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_daily_used', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, + 'quota_daily_start', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_weekly_used', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, + 
'quota_weekly_start', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + ), updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING + COALESCE((extra->>'quota_used')::numeric, 0), + COALESCE((extra->>'quota_limit')::numeric, 0)`, + amount, accountID) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + var newUsed, limit float64 + if rows.Next() { + if err := rows.Scan(&newUsed, &limit); err != nil { + return err + } + } else { + if err := rows.Err(); err != nil { + return err + } + return service.ErrAccountNotFound + } + if err := rows.Err(); err != nil { + return err + } + if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { + if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { + logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) + return err + } + } + return nil +} diff --git a/backend/internal/repository/usage_billing_repo_integration_test.go b/backend/internal/repository/usage_billing_repo_integration_test.go new file mode 100644 index 00000000..eda34cc9 --- /dev/null +++ b/backend/internal/repository/usage_billing_repo_integration_test.go @@ -0,0 +1,279 @@ +//go:build integration + +package repository + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: 
fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-" + uuid.NewString(), + Name: "billing", + Quota: 1, + }) + account := mustCreateAccount(t, client, &service.Account{ + Name: "usage-billing-account-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: account.ID, + AccountType: service.AccountTypeAPIKey, + BalanceCost: 1.25, + APIKeyQuotaCost: 1.25, + APIKeyRateLimitCost: 1.25, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.NotNil(t, result1) + require.True(t, result1.Applied) + require.True(t, result1.APIKeyQuotaExhausted) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.NotNil(t, result2) + require.False(t, result2.Applied) + + var balance float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance)) + require.InDelta(t, 98.75, balance, 0.000001) + + var quotaUsed float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed)) + require.InDelta(t, 1.25, quotaUsed, 0.000001) + + var usage5h float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h)) + require.InDelta(t, 1.25, usage5h, 0.000001) + + var status string + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status)) + require.Equal(t, service.StatusAPIKeyQuotaExhausted, status) + + var dedupCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = 
$2", requestID, apiKey.ID).Scan(&dedupCount)) + require.Equal(t, 1, dedupCount) +} + +func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + }) + group := mustCreateGroup(t, client, &service.Group{ + Name: "usage-billing-group-" + uuid.NewString(), + Platform: service.PlatformAnthropic, + SubscriptionType: service.SubscriptionTypeSubscription, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + GroupID: &group.ID, + Key: "sk-usage-billing-sub-" + uuid.NewString(), + Name: "billing-sub", + }) + subscription := mustCreateSubscription(t, client, &service.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: 0, + SubscriptionID: &subscription.ID, + SubscriptionCost: 2.5, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.True(t, result1.Applied) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.False(t, result2.Applied) + + var dailyUsage float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage)) + require.InDelta(t, 2.5, dailyUsage, 0.000001) +} + +func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + 
Balance:      100,
+	})
+	apiKey := mustCreateApiKey(t, client, &service.APIKey{
+		UserID: user.ID,
+		Key:    "sk-usage-billing-conflict-" + uuid.NewString(),
+		Name:   "billing-conflict",
+	})
+
+	requestID := uuid.NewString()
+	_, err := repo.Apply(ctx, &service.UsageBillingCommand{
+		RequestID:   requestID,
+		APIKeyID:    apiKey.ID,
+		UserID:      user.ID,
+		BalanceCost: 1.25,
+	})
+	require.NoError(t, err)
+
+	_, err = repo.Apply(ctx, &service.UsageBillingCommand{
+		RequestID:   requestID,
+		APIKeyID:    apiKey.ID,
+		UserID:      user.ID,
+		BalanceCost: 2.50,
+	})
+	require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
+}
+
+func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
+	ctx := context.Background()
+	client := testEntClient(t)
+	repo := NewUsageBillingRepository(client, integrationDB)
+
+	user := mustCreateUser(t, client, &service.User{
+		Email:        fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
+		PasswordHash: "hash",
+	})
+	apiKey := mustCreateApiKey(t, client, &service.APIKey{
+		UserID: user.ID,
+		Key:    "sk-usage-billing-account-" + uuid.NewString(),
+		Name:   "billing-account",
+	})
+	account := mustCreateAccount(t, client, &service.Account{
+		Name: "usage-billing-account-quota-" + uuid.NewString(),
+		Type: service.AccountTypeAPIKey,
+		Extra: map[string]any{
+			"quota_limit": 100.0,
+		},
+	})
+
+	_, err := repo.Apply(ctx, &service.UsageBillingCommand{
+		RequestID:        uuid.NewString(),
+		APIKeyID:         apiKey.ID,
+		UserID:           user.ID,
+		AccountID:        account.ID,
+		AccountType:      service.AccountTypeAPIKey,
+		AccountQuotaCost: 3.5,
+	})
+	require.NoError(t, err)
+
+	var quotaUsed float64
+	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed))
+	require.InDelta(t, 3.5, quotaUsed, 0.000001)
+}
+
+func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
+	ctx := context.Background()
+	repo := 
newDashboardAggregationRepositoryWithSQL(integrationDB) + + oldRequestID := "dedup-old-" + uuid.NewString() + newRequestID := "dedup-new-" + uuid.NewString() + oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400) + newCreatedAt := time.Now().UTC().Add(-time.Hour) + + _, err := integrationDB.ExecContext(ctx, ` + INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at) + VALUES ($1, 1, $2, $3), ($4, 1, $5, $6) + `, + oldRequestID, strings.Repeat("a", 64), oldCreatedAt, + newRequestID, strings.Repeat("b", 64), newCreatedAt, + ) + require.NoError(t, err) + + require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365))) + + var oldCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount)) + require.Equal(t, 0, oldCount) + + var newCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount)) + require.Equal(t, 1, newCount) + + var archivedCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount)) + require.Equal(t, 1, archivedCount) +} + +func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-archive-" + uuid.NewString(), + Name: "billing-archive", + }) + + requestID := uuid.NewString() + cmd := 
&service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 1.25, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.True(t, result1.Applied) + + _, err = integrationDB.ExecContext(ctx, ` + UPDATE usage_billing_dedup + SET created_at = $1 + WHERE request_id = $2 AND api_key_id = $3 + `, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID) + require.NoError(t, err) + require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365))) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.False(t, result2.Applied) + + var balance float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance)) + require.InDelta(t, 98.75, balance, 0.000001) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 8ffcb2f3..5e81818b 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3,12 +3,14 @@ package repository import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "os" "strconv" "strings" "sync" + "sync/atomic" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -17,11 +19,13 @@ import ( dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" + gocache "github.com/patrickmn/go-cache" ) const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, 
cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" @@ -47,18 +51,29 @@ type usageLogRepository struct { sql sqlExecutor db *sql.DB - createBatchOnce sync.Once - createBatchCh chan usageLogCreateRequest + createBatchOnce sync.Once + createBatchCh chan usageLogCreateRequest + bestEffortBatchOnce sync.Once + bestEffortBatchCh chan usageLogBestEffortRequest + bestEffortRecent *gocache.Cache } const ( usageLogCreateBatchMaxSize = 64 usageLogCreateBatchWindow = 3 * time.Millisecond usageLogCreateBatchQueueCap = 4096 + usageLogCreateCancelWait = 2 * time.Second + + usageLogBestEffortBatchMaxSize = 256 + usageLogBestEffortBatchWindow = 20 * time.Millisecond + usageLogBestEffortBatchQueueCap = 32768 + usageLogBestEffortRecentTTL = 30 * time.Second ) type usageLogCreateRequest struct { log *service.UsageLog + prepared usageLogInsertPrepared + shared *usageLogCreateShared resultCh chan usageLogCreateResult } @@ -67,6 +82,12 @@ type usageLogCreateResult struct { err error } +type usageLogBestEffortRequest struct { + prepared usageLogInsertPrepared + apiKeyID int64 + resultCh chan error +} + type usageLogInsertPrepared struct { createdAt time.Time requestID string @@ -80,6 +101,25 @@ type usageLogBatchState struct { CreatedAt time.Time } +type usageLogBatchRow struct { + RequestID string `json:"request_id"` + APIKeyID int64 `json:"api_key_id"` + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Inserted bool `json:"inserted"` +} + +type usageLogCreateShared struct { + state atomic.Int32 +} + +const ( + usageLogCreateStateQueued int32 = iota + usageLogCreateStateProcessing + usageLogCreateStateCompleted + 
usageLogCreateStateCanceled +) + func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository { return newUsageLogRepositoryWithSQL(client, sqlDB) } @@ -90,6 +130,7 @@ func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usage if db, ok := sqlq.(*sql.DB); ok { repo.db = db } + repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute) return repo } @@ -124,9 +165,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) if tx := dbent.TxFromContext(ctx); tx != nil { return r.createSingle(ctx, tx.Client(), log) } - if r.db == nil { - return r.createSingle(ctx, r.sql, log) - } requestID := strings.TrimSpace(log.RequestID) if requestID == "" { return r.createSingle(ctx, r.sql, log) @@ -135,11 +173,61 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) return r.createBatched(ctx, log) } +func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error { + if log == nil { + return nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + _, err := r.createSingle(ctx, tx.Client(), log) + return err + } + if r.db == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + r.ensureBestEffortBatcher() + if r.bestEffortBatchCh == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + req := usageLogBestEffortRequest{ + prepared: prepareUsageLogInsert(log), + apiKeyID: log.APIKeyID, + resultCh: make(chan error, 1), + } + if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok { + if _, exists := r.bestEffortRecent.Get(key); exists { + return nil + } + } + + select { + case r.bestEffortBatchCh <- req: + case <-ctx.Done(): + return ctx.Err() + default: + return errors.New("usage log best-effort queue full") + } + + select { + case err := <-req.resultCh: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + func (r *usageLogRepository) createSingle(ctx 
context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) { prepared := prepareUsageLogInsert(log) if sqlq == nil { sqlq = r.sql } + if ctx != nil && ctx.Err() != nil { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } query := ` INSERT INTO usage_logs ( @@ -218,13 +306,15 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa req := usageLogCreateRequest{ log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, resultCh: make(chan usageLogCreateResult, 1), } select { case r.createBatchCh <- req: case <-ctx.Done(): - return false, ctx.Err() + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) default: return r.createSingle(ctx, r.sql, log) } @@ -233,7 +323,17 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa case res := <-req.resultCh: return res.inserted, res.err case <-ctx.Done(): - return false, ctx.Err() + if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } + timer := time.NewTimer(usageLogCreateCancelWait) + defer timer.Stop() + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-timer.C: + return false, ctx.Err() + } } } @@ -247,6 +347,16 @@ func (r *usageLogRepository) ensureCreateBatcher() { }) } +func (r *usageLogRepository) ensureBestEffortBatcher() { + if r == nil || r.db == nil { + return + } + r.bestEffortBatchOnce.Do(func() { + r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap) + go r.runBestEffortBatcher(r.db) + }) +} + func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { for { first, ok := <-r.createBatchCh @@ -281,6 +391,40 @@ func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { } } +func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) { + for { + first, ok := <-r.bestEffortBatchCh + if !ok { + 
return + } + + batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogBestEffortBatchWindow) + bestEffortLoop: + for len(batch) < usageLogBestEffortBatchMaxSize { + select { + case req, ok := <-r.bestEffortBatchCh: + if !ok { + break bestEffortLoop + } + batch = append(batch, req) + case <-timer.C: + break bestEffortLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushBestEffortBatch(db, batch) + } +} + func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) { if len(batch) == 0 { return @@ -293,10 +437,19 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate for _, req := range batch { if req.log == nil { - sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil}) + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) continue } - prepared := prepareUsageLogInsert(req.log) + if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) { + if req.shared.state.Load() == usageLogCreateStateCanceled { + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: false, + err: service.MarkUsageLogCreateNotPersisted(context.Canceled), + }) + continue + } + } + prepared := req.prepared if prepared.requestID == "" { fallback = append(fallback, req) continue @@ -310,10 +463,37 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate } if len(uniqueOrder) > 0 { - insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) + insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) if err != nil { - for _, key := range uniqueOrder { - fallback = append(fallback, requestsByKey[key]...) 
+ if safeFallback { + for _, key := range uniqueOrder { + fallback = append(fallback, requestsByKey[key]...) + } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, hasState := stateMap[key] + inserted := insertedMap[key] + for idx, req := range reqs { + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + if hasState { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + } + switch { + case inserted && idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil}) + case inserted: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case hasState: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err}) + default: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + } + } + } } } else { for _, key := range uniqueOrder { @@ -321,7 +501,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate state, ok := stateMap[key] if !ok { for _, req := range reqs { - sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{ + completeUsageLogCreateRequest(req, usageLogCreateResult{ inserted: false, err: fmt.Errorf("usage log batch state missing for key=%s", key), }) @@ -332,7 +512,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate req.log.ID = state.ID req.log.CreatedAt = state.CreatedAt req.log.RateMultiplier = preparedByKey[key].rateMultiplier - sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{ + completeUsageLogCreateRequest(req, usageLogCreateResult{ inserted: idx == 0 && insertedMap[key], err: nil, }) @@ -345,56 +525,366 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate return } - fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() for 
_, req := range fallback { + fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) inserted, err := r.createSingle(fallbackCtx, db, req.log) - sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err}) + cancel() + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err}) } } -func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) { +func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) { + if len(batch) == 0 { + return + } + + type bestEffortGroup struct { + prepared usageLogInsertPrepared + apiKeyID int64 + key string + reqs []usageLogBestEffortRequest + } + + groupsByKey := make(map[string]*bestEffortGroup, len(batch)) + groupOrder := make([]*bestEffortGroup, 0, len(batch)) + preparedList := make([]usageLogInsertPrepared, 0, len(batch)) + + for idx, req := range batch { + prepared := req.prepared + key := fmt.Sprintf("__best_effort_%d", idx) + if prepared.requestID != "" { + key = usageLogBatchKey(prepared.requestID, req.apiKeyID) + } + group, exists := groupsByKey[key] + if !exists { + group = &bestEffortGroup{ + prepared: prepared, + apiKeyID: req.apiKeyID, + key: key, + } + groupsByKey[key] = group + groupOrder = append(groupOrder, group) + preparedList = append(preparedList, prepared) + } + group.reqs = append(group.reqs, req) + } + + if len(preparedList) == 0 { + for _, req := range batch { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBestEffortInsertQuery(preparedList) + if _, err := db.ExecContext(ctx, query, args...); err != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err) + for _, group := range groupOrder { 
+ singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared) + if singleErr != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr) + } else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, singleErr) + } + } + return + } + for _, group := range groupOrder { + if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + } +} + +func sendUsageLogBestEffortResult(ch chan error, err error) { + if ch == nil { + return + } + select { + case ch <- err: + default: + } +} + +func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) { + if req.shared != nil { + req.shared.state.Store(usageLogCreateStateCompleted) + } + sendUsageLogCreateResult(req.resultCh, res) +} + +func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) { if len(keys) == 0 { - return map[string]bool{}, map[string]usageLogBatchState{}, nil + return map[string]bool{}, map[string]usageLogBatchState{}, false, nil } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey) - rows, err := db.QueryContext(ctx, query, args...) 
- if err != nil { - return nil, nil, err + var payload []byte + if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil { + return nil, nil, true, err + } + var rows []usageLogBatchRow + if err := json.Unmarshal(payload, &rows); err != nil { + return nil, nil, false, err } insertedMap := make(map[string]bool, len(keys)) - for rows.Next() { - var ( - requestID string - apiKeyID int64 - id int64 - createdAt time.Time - ) - if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil { - _ = rows.Close() - return nil, nil, err + stateMap := make(map[string]usageLogBatchState, len(keys)) + for _, row := range rows { + key := usageLogBatchKey(row.RequestID, row.APIKeyID) + insertedMap[key] = row.Inserted + stateMap[key] = usageLogBatchState{ + ID: row.ID, + CreatedAt: row.CreatedAt, } - insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true } - if err := rows.Err(); err != nil { - _ = rows.Close() - return nil, nil, err + if len(stateMap) != len(keys) { + return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys)) } - _ = rows.Close() - - stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey) - if err != nil { - return nil, nil, err - } - return insertedMap, stateMap, nil + return insertedMap, stateMap, false, nil } func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) { var query strings.Builder _, _ = query.WriteString(` + WITH input ( + input_idx, + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, 
+ first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(keys)*37) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + args = append(args, idx) + argPos++ + prepared := preparedByKey[key] + for i := 0; i < len(prepared.args); i++ { + _, _ = query.WriteString(",") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) + } + _, _ = query.WriteString(` + ), + inserted AS ( + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + 
service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + FROM input + ON CONFLICT (request_id, api_key_id) DO UPDATE + SET request_id = usage_logs.request_id + RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted + ) + SELECT COALESCE( + json_agg( + json_build_object( + 'request_id', inserted.request_id, + 'api_key_id', inserted.api_key_id, + 'id', inserted.id, + 'created_at', inserted.created_at, + 'inserted', inserted.inserted + ) + ORDER BY input.input_idx + ), + '[]'::json + ) + FROM input + JOIN inserted + ON inserted.request_id = input.request_id + AND inserted.api_key_id = input.api_key_id + `) + return query.String(), args +} + +func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + WITH input ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(preparedList)*36) + argPos := 1 + for idx, prepared := range preparedList { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + for i := 0; i < len(prepared.args); i++ { + if i > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) 
+ } + + _, _ = query.WriteString(` + ) INSERT INTO usage_logs ( user_id, api_key_id, @@ -432,80 +922,101 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage reasoning_effort, cache_ttl_overridden, created_at - ) VALUES `) - - args := make([]any, 0, len(keys)*36) - argPos := 1 - for idx, key := range keys { - if idx > 0 { - _, _ = query.WriteString(",") - } - _, _ = query.WriteString("(") - prepared := preparedByKey[key] - for i := 0; i < len(prepared.args); i++ { - if i > 0 { - _, _ = query.WriteString(",") - } - _, _ = query.WriteString("$") - _, _ = query.WriteString(strconv.Itoa(argPos)) - argPos++ - } - _, _ = query.WriteString(")") - args = append(args, prepared.args...) - } - _, _ = query.WriteString(` + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING - RETURNING request_id, api_key_id, id, created_at `) + return query.String(), args } -func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) { - var query strings.Builder - _, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `) - args := make([]any, 0, len(keys)*2) - argPos := 1 - for idx, key := range keys { - if idx > 0 { - _, _ = query.WriteString(" OR ") - } - prepared := preparedByKey[key] - apiKeyID := 
prepared.args[1] - _, _ = query.WriteString("(request_id = $") - _, _ = query.WriteString(strconv.Itoa(argPos)) - _, _ = query.WriteString(" AND api_key_id = $") - _, _ = query.WriteString(strconv.Itoa(argPos + 1)) - _, _ = query.WriteString(")") - args = append(args, prepared.requestID, apiKeyID) - argPos += 2 - } - - rows, err := db.QueryContext(ctx, query.String(), args...) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - stateMap := make(map[string]usageLogBatchState, len(keys)) - for rows.Next() { - var ( - requestID string - apiKeyID int64 - id int64 - createdAt time.Time +func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error { + _, err := sqlq.ExecContext(ctx, ` + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, $11, + $12, $13, + $14, $15, $16, $17, $18, $19, + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36 ) - if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil { - return nil, err - } - stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{ - ID: id, - CreatedAt: createdAt, - } - } - if err := rows.Err(); err != nil { - return nil, err - } - return stateMap, nil + ON CONFLICT (request_id, api_key_id) DO NOTHING + `, prepared.args...) 
+ return err } func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { @@ -597,6 +1108,14 @@ func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateRe } } +func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) { + requestID = strings.TrimSpace(requestID) + if requestID == "" || r == nil || r.bestEffortRecent == nil { + return "", false + } + return usageLogBatchKey(requestID, apiKeyID), true +} + func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" rows, err := r.sql.QueryContext(ctx, query, id) diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index d2e1e9d4..00740878 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -183,6 +183,214 @@ func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) { require.Equal(t, 1, count) } +func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()}) + requestID := uuid.NewString() + + const total = 8 + batch := make([]usageLogCreateRequest, 0, total) + logs := make([]*service.UsageLog, 0, total) + + for i := 0; i < total; i++ { + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, 
+ AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10 + i, + OutputTokens: 20 + i, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + logs = append(logs, log) + batch = append(batch, usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + resultCh: make(chan usageLogCreateResult, 1), + }) + } + + repo.flushCreateBatch(integrationDB, batch) + + insertedCount := 0 + var firstID int64 + for idx, req := range batch { + res := <-req.resultCh + require.NoError(t, res.err) + if res.inserted { + insertedCount++ + } + require.NotZero(t, logs[idx].ID) + if idx == 0 { + firstID = logs[idx].ID + } else { + require.Equal(t, firstID, logs[idx].ID) + } + } + + require.Equal(t, 1, insertedCount) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)) + require.Equal(t, 1, count) +} + +func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()}) + requestID := uuid.NewString() + + log1 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + log2 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + 
Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + + require.NoError(t, repo.CreateBestEffort(ctx, log1)) + require.NoError(t, repo.CreateBestEffort(ctx, log2)) + + require.Eventually(t, func() bool { + var count int + err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count) + return err == nil && count == 1 + }, 3*time.Second, 20*time.Millisecond) +} + +func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + inserted, err := repo.Create(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.False(t, inserted) + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) +} + +func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.createBatchCh = make(chan usageLogCreateRequest, 1) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, 
&service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()}) + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + + go func() { + _, err := repo.createBatched(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + errCh <- err + }() + + req := <-repo.createBatchCh + require.NotNil(t, req.shared) + cancel() + + err := <-errCh + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)}) +} + +func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + req := usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, + resultCh: make(chan usageLogCreateResult, 1), + } + req.shared.state.Store(usageLogCreateStateCanceled) + 
+ repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req}) + + res := <-req.resultCh + require.False(t, res.inserted) + require.Error(t, res.err) + require.True(t, service.IsUsageLogCreateNotPersisted(res.err)) +} + func (s *UsageLogRepoSuite) TestGetByID() { user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"}) apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 5fe7a98e..01395bcb 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet( NewAnnouncementRepository, NewAnnouncementReadRepository, NewUsageLogRepository, + NewUsageBillingRepository, NewIdempotencyRepository, NewUsageCleanupRepository, NewDashboardAggregationRepository, diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index a67f8532..b58a1ea9 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -35,6 +35,7 @@ type DashboardAggregationRepository interface { UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error CleanupUsageLogs(ctx context.Context, cutoff time.Time) error + CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error } @@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays) dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays) usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays) + dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays) 
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff) if aggErr != nil { @@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, if usageErr != nil { logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) } - if aggErr == nil && usageErr == nil { + dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff) + if dedupErr != nil { + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr) + } + if aggErr == nil && usageErr == nil && dedupErr == nil { s.lastRetentionCleanup.Store(now) } } diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go index a7058985..fbb671bb 100644 --- a/backend/internal/service/dashboard_aggregation_service_test.go +++ b/backend/internal/service/dashboard_aggregation_service_test.go @@ -12,12 +12,18 @@ import ( type dashboardAggregationRepoTestStub struct { aggregateCalls int + recomputeCalls int + cleanupUsageCalls int + cleanupDedupCalls int + ensurePartitionCalls int lastStart time.Time lastEnd time.Time watermark time.Time aggregateErr error cleanupAggregatesErr error cleanupUsageErr error + cleanupDedupErr error + ensurePartitionErr error } func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error { @@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s } func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + s.recomputeCalls++ return s.AggregateRange(ctx, start, end) } @@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context } func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + s.cleanupUsageCalls++ return s.cleanupUsageErr } 
+func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + s.cleanupDedupCalls++ + return s.cleanupDedupErr +} + func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { - return nil + s.ensurePartitionCalls++ + return s.ensurePartitionErr } func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) { @@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) require.Nil(t, svc.lastRetentionCleanup.Load()) + require.Equal(t, 1, repo.cleanupUsageCalls) + require.Equal(t, 1, repo.cleanupDedupCalls) +} + +func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) + + require.Nil(t, svc.lastRetentionCleanup.Load()) + require.Equal(t, 1, repo.cleanupDedupCalls) +} + +func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Enabled: true, + IntervalSeconds: 60, + LookbackSeconds: 120, + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + UsageBillingDedupDays: 2, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.runScheduledAggregation() + + require.Equal(t, 1, repo.ensurePartitionCalls) + require.Equal(t, 1, repo.aggregateCalls) } func 
TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) { diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index 59b83e66..2a7f47b6 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut return nil } +func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + return nil +} + func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { return nil } diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 5dcda1de..789cbab8 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd }, } - svc := &GatewayService{ - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - MaxLineSize: defaultMaxLineSize, - }, + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, }, - httpUpstream: upstream, - rateLimitService: &RateLimitService{}, - deferredService: &DeferredService{}, - billingCacheService: nil, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + billingCacheService: nil, } account := &Account{ @@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo }, } - svc := &GatewayService{ - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - MaxLineSize: defaultMaxLineSize, - }, + cfg := 
&config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, }, - httpUpstream: upstream, - rateLimitService: &RateLimitService{}, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, } account := &Account{ @@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf require.Equal(t, 5, result.usage.OutputTokens) } +func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":5}}`, + "", + }, "\n"))), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) +} + func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi _ = pr.Close() <-done - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete after timeout") 
require.NotNil(t, result) require.True(t, result.clientDisconnect) require.Equal(t, 9, result.usage.InputTokens) @@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t } result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219") - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete") require.NotNil(t, result) require.True(t, result.clientDisconnect) } @@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft } result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219") - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete after disconnect") require.NotNil(t, result) require.True(t, result.clientDisconnect) require.Equal(t, 8, result.usage.InputTokens) diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go new file mode 100644 index 00000000..92e59ac8 --- /dev/null +++ b/backend/internal/service/gateway_record_usage_test.go @@ -0,0 +1,261 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.1 + return NewGatewayService( + nil, + nil, + usageRepo, + nil, + userRepo, + subRepo, + nil, + nil, + cfg, + nil, + nil, + NewBillingService(cfg, nil), + nil, + &BillingCacheService{}, + nil, + nil, + &DeferredService{}, + nil, + nil, 
+ nil, + nil, + nil, + ) +} + +func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService { + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.usageBillingRepo = billingRepo + return svc +} + +func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_detached_ctx", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 501, + Quota: 100, + }, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`)) + err := svc.RecordUsage(context.Background(), 
&RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_payload_hash", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + RequestPayloadHash: payloadHash, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash) +} + +func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_payload_fallback", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) +} + +func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: 
"gateway_not_persisted", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 503, + Quota: 100, + }, + User: &User{ID: 603}, + Account: &Account{ID: 703}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 1, quotaSvc.quotaCalls) +} + +func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{ + Result: &ForwardResult{ + RequestID: "gateway_long_context_detached_ctx", + Usage: ClaudeUsage{ + InputTokens: 12, + OutputTokens: 8, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 502, + Quota: 100, + }, + User: &User{ID: 602}, + Account: &Account{ID: 702}, + LongContextThreshold: 200000, + LongContextMultiplier: 2, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + ctx := context.WithValue(context.Background(), 
ctxkey.RequestID, "gateway-local-fallback") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 504}, + User: &User{ID: 604}, + Account: &Account{ID: 704}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_billing_fail", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 505}, + User: &User{ID: 605}, + Account: &Account{ID: 705}, + }) + + require.Error(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 0, usageRepo.calls) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 080de063..670ff21e 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -50,6 +50,7 @@ const ( defaultUserGroupRateCacheTTL = 30 * time.Second defaultModelsListCacheTTL = 15 * time.Second + postUsageBillingTimeout = 15 * time.Second ) const ( @@ -106,6 +107,52 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() } +func claudeUsageHasAnyTokens(usage 
*ClaudeUsage) bool { + return usage != nil && (usage.InputTokens > 0 || + usage.OutputTokens > 0 || + usage.CacheCreationInputTokens > 0 || + usage.CacheReadInputTokens > 0 || + usage.CacheCreation5mTokens > 0 || + usage.CacheCreation1hTokens > 0) +} + +func openAIUsageHasAnyTokens(usage *OpenAIUsage) bool { + return usage != nil && (usage.InputTokens > 0 || + usage.OutputTokens > 0 || + usage.CacheCreationInputTokens > 0 || + usage.CacheReadInputTokens > 0) +} + +func openAIStreamEventIsTerminal(data string) bool { + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + switch gjson.Get(trimmed, "type").String() { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func anthropicStreamEventIsTerminal(eventName, data string) bool { + if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") { + return true + } + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + return gjson.Get(trimmed, "type").String() == "message_stop" +} + func cloneStringSlice(src []string) []string { if len(src) == 0 { return nil @@ -504,6 +551,7 @@ type GatewayService struct { accountRepo AccountRepository groupRepo GroupRepository usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository userRepo UserRepository userSubRepo UserSubscriptionRepository userGroupRateRepo UserGroupRateRepository @@ -537,6 +585,7 @@ func NewGatewayService( accountRepo AccountRepository, groupRepo GroupRepository, usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, userGroupRateRepo UserGroupRateRepository, @@ -563,6 +612,7 @@ func NewGatewayService( accountRepo: accountRepo, groupRepo: groupRepo, usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, userRepo: userRepo, userSubRepo: 
userSubRepo, userGroupRateRepo: userGroupRateRepo, @@ -4049,7 +4099,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -4127,7 +4179,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // also downgrade tool_use/tool_result blocks to text. filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx() if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -4159,7 +4213,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) - retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + 
retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream) + retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx2() if buildErr2 == nil { retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { @@ -4226,7 +4282,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A rectifiedBody, applied := RectifyThinkingBudget(body) if applied && time.Since(retryStart) < maxRetryElapsed { logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) - budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseBudgetRetryCtx() if buildErr == nil { budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -4498,7 +4556,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() if err != nil { return 
nil, err } @@ -4774,6 +4834,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( usage := &ClaudeUsage{} var firstTokenMs *int clientDisconnected := false + sawTerminalEvent := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -4836,17 +4897,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 flusher.Flush() } + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) } if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err()) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) } if errors.Is(ev.err, bufio.ErrTooLong) { logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d 
error=%v", account.ID, maxLineSize, ev.err) @@ -4858,11 +4922,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( line := ev.line if data, ok := extractAnthropicSSEDataLine(line); ok { trimmed := strings.TrimSpace(data) + if anthropicStreamEventIsTerminal("", trimmed) { + sawTerminalEvent = true + } if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } s.parseSSEUsagePassthrough(data, usage) + } else { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") { + sawTerminalEvent = true + } } if !clientDisconnected { @@ -4884,8 +4956,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( continue } if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) if s.rateLimitService != nil { @@ -6011,6 +6082,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + sawTerminalEvent := false pendingEventLines := make([]string, 0, 4) @@ -6041,6 +6113,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if dataLine == "[DONE]" { + sawTerminalEvent = true block := "" if eventName != "" { block = "event: " + eventName + "\n" @@ -6107,6 +6180,9 @@ 
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } usagePatch := s.extractSSEUsagePatch(event) + if anthropicStreamEventIsTerminal(eventName, dataLine) { + sawTerminalEvent = true + } if !eventChanged { block := "" if eventName != "" { @@ -6140,18 +6216,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http case ev, ok := <-events: if !ok { // 上游完成,返回结果 + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) } // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage if clientDisconnected { - logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) } // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { @@ -6209,9 +6289,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http continue } if 
clientDisconnected { - // 客户端已断开,上游也超时了,返回已收集的 usage - logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -6557,15 +6635,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 } // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage @@ -6574,6 +6653,14 @@ type APIKeyQuotaUpdater interface { UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error } +type apiKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) +} + +type usageLogBestEffortWriter interface { + CreateBestEffort(ctx context.Context, log *UsageLog) error +} + // postUsageBillingParams 统一扣费所需的参数 type postUsageBillingParams struct { Cost 
*CostBreakdown @@ -6581,6 +6668,7 @@ type postUsageBillingParams struct { APIKey *APIKey Account *Account Subscription *UserSubscription + RequestPayloadHash string IsSubscriptionBill bool AccountRateMultiplier float64 APIKeyService APIKeyQuotaUpdater @@ -6592,19 +6680,22 @@ type postUsageBillingParams struct { // - API Key 限速用量更新 // - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + cost := p.Cost // 1. 订阅 / 余额扣费 if p.IsSubscriptionBill { if cost.TotalCost > 0 { - if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) } } else { if cost.ActualCost > 0 { - if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil { + if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) } deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) @@ -6613,31 +6704,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill // 2. API Key 配额 if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { - if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) } } // 3. 
API Key 限速用量 if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { - if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) } - deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost) } // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { accountCost := cost.TotalCost * p.AccountRateMultiplier - if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil { + if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) } } - // 5. 更新账号最近使用时间 + finalizePostUsageBilling(p, deps) +} + +func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { + if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { + return requestID + } + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + return "" +} + +func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string { + if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" { + return payloadHash + } + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if 
requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + return "" +} + +func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand { + if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil { + return nil + } + + cmd := &UsageBillingCommand{ + RequestID: requestID, + APIKeyID: p.APIKey.ID, + UserID: p.User.ID, + AccountID: p.Account.ID, + AccountType: p.Account.Type, + RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash), + } + if usageLog != nil { + cmd.Model = usageLog.Model + cmd.BillingType = usageLog.BillingType + cmd.InputTokens = usageLog.InputTokens + cmd.OutputTokens = usageLog.OutputTokens + cmd.CacheCreationTokens = usageLog.CacheCreationTokens + cmd.CacheReadTokens = usageLog.CacheReadTokens + cmd.ImageCount = usageLog.ImageCount + if usageLog.MediaType != nil { + cmd.MediaType = *usageLog.MediaType + } + if usageLog.ServiceTier != nil { + cmd.ServiceTier = *usageLog.ServiceTier + } + if usageLog.ReasoningEffort != nil { + cmd.ReasoningEffort = *usageLog.ReasoningEffort + } + if usageLog.SubscriptionID != nil { + cmd.SubscriptionID = usageLog.SubscriptionID + } + } + + if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { + cmd.SubscriptionID = &p.Subscription.ID + cmd.SubscriptionCost = p.Cost.TotalCost + } else if p.Cost.ActualCost > 0 { + cmd.BalanceCost = p.Cost.ActualCost + } + + if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + cmd.APIKeyQuotaCost = p.Cost.ActualCost + } + if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + cmd.APIKeyRateLimitCost = p.Cost.ActualCost + } + if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { + cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier + } + + cmd.Normalize() + return cmd +} + +func 
applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) { + if p == nil || deps == nil { + return false, nil + } + + cmd := buildUsageBillingCommand(requestID, usageLog, p) + if cmd == nil || cmd.RequestID == "" || repo == nil { + postUsageBilling(ctx, p, deps) + return true, nil + } + + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + result, err := repo.Apply(billingCtx, cmd) + if err != nil { + return false, err + } + + if result == nil || !result.Applied { + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) + return false, nil + } + + if result.APIKeyQuotaExhausted { + if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" { + invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key) + } + } + + finalizePostUsageBilling(p, deps) + return true, nil +} + +func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { + if p == nil || p.Cost == nil || deps == nil { + return + } + + if p.IsSubscriptionBill { + if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + } + } else if p.Cost.ActualCost > 0 && p.User != nil { + deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) + } + + if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() { + deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost) + } + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) } +func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + return context.WithTimeout(base, postUsageBillingTimeout) +} + +func detachStreamUpstreamContext(ctx 
context.Context, stream bool) (context.Context, context.CancelFunc) { + if !stream { + return ctx, func() {} + } + if ctx == nil { + return context.Background(), func() {} + } + return context.WithoutCancel(ctx), func() {} +} + // billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) type billingDeps struct { accountRepo AccountRepository @@ -6657,6 +6904,28 @@ func (s *GatewayService) billingDeps() *billingDeps { } } +func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) { + if repo == nil || usageLog == nil { + return + } + usageCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + if writer, ok := repo.(usageLogBestEffortWriter); ok { + if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { + logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) + } + } + return + } + + if _, err := repo.Create(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + } +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { result := input.Result @@ -6758,11 +7027,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu mediaType = &result.MediaType } accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -6807,33 +7077,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, 
usageLog) - if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) - } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else { - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return nil } @@ -6844,13 +7113,14 @@ type RecordUsageLongContextInput struct { APIKey *APIKey User *User Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - LongContextThreshold int // 长上下文阈值(如 200000) - LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService *APIKeyService // API Key 配额服务(可选) + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 
2.0) + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) @@ -6933,11 +7203,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * imageSize = &result.ImageSize } accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -6981,33 +7252,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) - if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) - } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else { - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), 
s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return nil } diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index cd690cbd..b1584827 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) { result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) _ = pr.Close() - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") require.NotNil(t, result) } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 9529462e..f05fa5f5 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -7,35 +7,63 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/stretchr/testify/require" ) type openAIRecordUsageLogRepoStub struct { UsageLogRepository - inserted bool - err error - calls int - lastLog *UsageLog + inserted bool + err error + calls int + lastLog *UsageLog + lastCtxErr error } func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { s.calls++ s.lastLog = log + s.lastCtxErr = ctx.Err() return s.inserted, s.err } +type openAIRecordUsageBillingRepoStub struct { + UsageBillingRepository + + result *UsageBillingApplyResult + err error + calls int + lastCmd *UsageBillingCommand + lastCtxErr error +} + +func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) { + s.calls++ + 
s.lastCmd = cmd + s.lastCtxErr = ctx.Err() + if s.err != nil { + return nil, s.err + } + if s.result != nil { + return s.result, nil + } + return &UsageBillingApplyResult{Applied: true}, nil +} + type openAIRecordUsageUserRepoStub struct { UserRepository deductCalls int deductErr error lastAmount float64 + lastCtxErr error } func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { s.deductCalls++ s.lastAmount = amount + s.lastCtxErr = ctx.Err() return s.deductErr } @@ -44,29 +72,35 @@ type openAIRecordUsageSubRepoStub struct { incrementCalls int incrementErr error + lastCtxErr error } func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { s.incrementCalls++ + s.lastCtxErr = ctx.Err() return s.incrementErr } type openAIRecordUsageAPIKeyQuotaStub struct { - quotaCalls int - rateLimitCalls int - err error - lastAmount float64 + quotaCalls int + rateLimitCalls int + err error + lastAmount float64 + lastQuotaCtxErr error + lastRateLimitCtxErr error } func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { s.quotaCalls++ s.lastAmount = cost + s.lastQuotaCtxErr = ctx.Err() return s.err } func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { s.rateLimitCalls++ s.lastAmount = cost + s.lastRateLimitCtxErr = ctx.Err() return s.err } @@ -93,23 +127,38 @@ func i64p(v int64) *int64 { func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { cfg := &config.Config{} cfg.Default.RateMultiplier = 1.1 + svc := NewOpenAIGatewayService( + nil, + usageRepo, + nil, + userRepo, + subRepo, + rateRepo, + nil, + cfg, + nil, + nil, + NewBillingService(cfg, nil), + nil, + &BillingCacheService{}, + nil, + &DeferredService{}, + nil, 
+ ) + svc.userGroupRateResolver = newUserGroupRateResolver( + rateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway.test", + ) + return svc +} - return &OpenAIGatewayService{ - usageLogRepo: usageRepo, - userRepo: userRepo, - userSubRepo: subRepo, - cfg: cfg, - billingService: NewBillingService(cfg, nil), - billingCacheService: &BillingCacheService{}, - deferredService: &DeferredService{}, - userGroupRateResolver: newUserGroupRateResolver( - rateRepo, - nil, - resolveUserGroupRateCacheTTL(cfg), - nil, - "service.openai_gateway.test", - ), - } +func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + svc.usageBillingRepo = billingRepo + return svc } func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown { @@ -252,9 +301,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}} userRepo := &openAIRecordUsageUserRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{} - svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ Result: &OpenAIForwardResult{ @@ -272,11 +322,254 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin }) require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) 
require.Equal(t, 1, usageRepo.calls) require.Equal(t, 0, userRepo.deductCalls) require.Equal(t, 0, subRepo.incrementCalls) } +func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate_billing_key", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10045, + Quota: 100, + }, + User: &User{ID: 20045}, + Account: &Account{ID: 30045}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.quotaCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) { + usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_usage_log_error", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10041}, + User: &User{ID: 20041}, + Account: 
&Account{ID: 30041}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_not_persisted", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10043, + Quota: 100, + }, + User: &User{ID: 20043}, + Account: &Account{ID: 30043}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 1, quotaSvc.quotaCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_detached_billing_ctx", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + 
ID: 10042, + Quota: 100, + }, + User: &User{ID: 20042}, + Account: &Account{ID: 30042}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_detached_billing_repo_ctx", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10046}, + User: &User{ID: 20046}, + Account: &Account{ID: 30046}, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.NoError(t, billingRepo.lastCtxErr) + require.Equal(t, 1, usageRepo.calls) + require.NoError(t, usageRepo.lastCtxErr) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`)) + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: 
"openai_payload_hash", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "gpt-5", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + RequestPayloadHash: payloadHash, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash) +} + +func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback") + err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10047}, + User: &User{ID: 20047}, + Account: &Account{ID: 30047}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := 
svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_billing_fail", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10048}, + User: &User{ID: 20048}, + Account: &Account{ID: 30048}, + }) + + require.Error(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 0, usageRepo.calls) +} + func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 44cfc83a..241c5cd6 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -259,6 +259,7 @@ type openAIWSRetryMetrics struct { type OpenAIGatewayService struct { accountRepo AccountRepository usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository userRepo UserRepository userSubRepo UserSubscriptionRepository cache GatewayCache @@ -295,6 +296,7 @@ type OpenAIGatewayService struct { func NewOpenAIGatewayService( accountRepo AccountRepository, usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, userGroupRateRepo UserGroupRateRepository, @@ -312,6 +314,7 @@ func NewOpenAIGatewayService( svc := &OpenAIGatewayService{ accountRepo: accountRepo, usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, userRepo: userRepo, userSubRepo: userSubRepo, cache: cache, @@ -2014,7 +2017,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + 
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -2206,7 +2211,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, err } - upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -2543,6 +2550,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( var firstTokenMs *int clientDisconnected := false sawDone := false + sawTerminalEvent := false upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) scanner := bufio.NewScanner(resp.Body) @@ -2562,6 +2570,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( if trimmedData == "[DONE]" { sawDone = true } + if openAIStreamEventIsTerminal(trimmedData) { + sawTerminalEvent = true + } if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms @@ -2579,19 +2590,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } } if err := scanner.Err(); err != nil { - if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err) + if sawTerminalEvent { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } + if clientDisconnected { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + } if errors.Is(err, context.Canceled) || errors.Is(err, 
context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", - "[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, - upstreamRequestID, - err, - ctx.Err(), - ) - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) } if errors.Is(err, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) @@ -2605,12 +2611,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) } - if !clientDisconnected && !sawDone && ctx.Err() == nil { + if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { logger.FromContext(ctx).With( zap.String("component", "service.openai_gateway"), zap.Int64("account_id", account.ID), zap.String("upstream_request_id", upstreamRequestID), ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") } return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil @@ -3030,6 +3037,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。 errorEventSent := false clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage + sawTerminalEvent := false sendErrorEvent := func(reason string) { if errorEventSent || clientDisconnected { return @@ -3060,22 +3068,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning 
collected usage") } } + if !sawTerminalEvent { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } return resultWithUsage(), nil } handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { if scanErr == nil { return nil, nil, false } + if sawTerminalEvent { + logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr) + return resultWithUsage(), nil, true + } // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") - return resultWithUsage(), nil, true + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true } // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr) - return resultWithUsage(), nil, true + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true } if errors.Is(scanErr, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) @@ -3098,6 +3111,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } dataBytes := []byte(data) + if openAIStreamEventIsTerminal(data) { + sawTerminalEvent = true + } // Correct Codex tool calls if needed (apply_patch -> edit, etc.) 
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { @@ -3214,8 +3230,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp continue } if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") - return resultWithUsage(), nil + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -3313,11 +3328,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { return } - // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 - if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { + // 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。 + if len(data) < 72 { return } - if gjson.GetBytes(data, "type").String() != "response.completed" { + eventType := gjson.GetBytes(data, "type").String() + if eventType != "response.completed" && eventType != "response.done" { return } @@ -3670,14 +3686,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - APIKeyService APIKeyQuotaUpdater + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string + APIKeyService APIKeyQuotaUpdater } // RecordUsage records usage and deducts balance @@ -3743,11 +3760,12 @@ func (s *OpenAIGatewayService) 
RecordUsage(ctx context.Context, input *OpenAIRec // Create usage log durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: billingModel, ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, @@ -3788,29 +3806,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else { - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") return nil } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 
43e2f39d..9e2f33f2 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) { } } -func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { +func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ @@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { } _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") - if err != nil { - t.Fatalf("expected nil error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") { + t.Fatalf("expected incomplete stream error, got %v", err) } if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) @@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { } } +func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + }() + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, 
&Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + if err == nil || !strings.Contains(err.Error(), "missing terminal event") { + t.Fatalf("expected missing terminal event error, got %v", err) + } +} + +func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + }() + + _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now()) + _ = pr.Close() + if err == nil || !strings.Contains(err.Error(), "missing terminal event") { + t.Fatalf("expected missing terminal event error, got %v", err) + } +} + +func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), 
resp, c, &Account{ID: 1}, time.Now()) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 2, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.Equal(t, 1, result.usage.CacheReadInputTokens) +} + func TestOpenAIStreamingTooLong(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n")) }() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") @@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) { require.Equal(t, 3, usage.InputTokens) require.Equal(t, 5, usage.OutputTokens) require.Equal(t, 2, usage.CacheReadInputTokens) + + // done 事件同样可能携带最终 usage + svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage) + require.Equal(t, 13, usage.InputTokens) + require.Equal(t, 15, usage.OutputTokens) + require.Equal(t, 4, usage.CacheReadInputTokens) } func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 7295b13d..08eb397b 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, cfg, nil, nil, diff --git a/backend/internal/service/usage_billing.go b/backend/internal/service/usage_billing.go new file mode 100644 index 00000000..73b05743 --- /dev/null +++ 
b/backend/internal/service/usage_billing.go @@ -0,0 +1,110 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "strings" +) + +var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required") +var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict") + +// UsageBillingCommand describes one billable request that must be applied at most once. +type UsageBillingCommand struct { + RequestID string + APIKeyID int64 + RequestFingerprint string + RequestPayloadHash string + + UserID int64 + AccountID int64 + SubscriptionID *int64 + AccountType string + Model string + ServiceTier string + ReasoningEffort string + BillingType int8 + InputTokens int + OutputTokens int + CacheCreationTokens int + CacheReadTokens int + ImageCount int + MediaType string + + BalanceCost float64 + SubscriptionCost float64 + APIKeyQuotaCost float64 + APIKeyRateLimitCost float64 + AccountQuotaCost float64 +} + +func (c *UsageBillingCommand) Normalize() { + if c == nil { + return + } + c.RequestID = strings.TrimSpace(c.RequestID) + if strings.TrimSpace(c.RequestFingerprint) == "" { + c.RequestFingerprint = buildUsageBillingFingerprint(c) + } +} + +func buildUsageBillingFingerprint(c *UsageBillingCommand) string { + if c == nil { + return "" + } + raw := fmt.Sprintf( + "%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f", + c.UserID, + c.AccountID, + c.APIKeyID, + strings.TrimSpace(c.AccountType), + strings.TrimSpace(c.Model), + strings.TrimSpace(c.ServiceTier), + strings.TrimSpace(c.ReasoningEffort), + c.BillingType, + c.InputTokens, + c.OutputTokens, + c.CacheCreationTokens, + c.CacheReadTokens, + c.ImageCount, + strings.TrimSpace(c.MediaType), + valueOrZero(c.SubscriptionID), + c.BalanceCost, + c.SubscriptionCost, + c.APIKeyQuotaCost, + c.APIKeyRateLimitCost, + c.AccountQuotaCost, + ) + if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" 
{ + raw += "|" + payloadHash + } + sum := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(sum[:]) +} + +func HashUsageRequestPayload(payload []byte) string { + if len(payload) == 0 { + return "" + } + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:]) +} + +func valueOrZero(v *int64) int64 { + if v == nil { + return 0 + } + return *v +} + +type UsageBillingApplyResult struct { + Applied bool + APIKeyQuotaExhausted bool +} + +type UsageBillingRepository interface { + Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) +} diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go index 0fdbfd47..17f21bef 100644 --- a/backend/internal/service/usage_cleanup_service_test.go +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -56,7 +56,8 @@ type cleanupRepoStub struct { } type dashboardRepoStub struct { - recomputeErr error + recomputeErr error + recomputeCalls int } func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error { @@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time. 
} func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + s.recomputeCalls++ return s.recomputeErr } @@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti return nil } +func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + return nil +} + func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { return nil } @@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) { } func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { + dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")} repo := &cleanupRepoStub{ deleteQueue: []cleanupDeleteResponse{ {deleted: 0}, }, } - dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ - DashboardAgg: config.DashboardAggregationConfig{Enabled: false}, + dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{ + DashboardAgg: config.DashboardAggregationConfig{Enabled: true}, }) cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} svc := NewUsageCleanupService(repo, nil, dashboard, cfg) @@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { repo.mu.Lock() defer repo.mu.Unlock() require.Len(t, repo.markSucceeded, 1) + require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond) } func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { + dashboardRepo := &dashboardRepoStub{} repo := &cleanupRepoStub{ deleteQueue: []cleanupDeleteResponse{ {deleted: 0}, }, } - dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ + dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{ DashboardAgg: 
config.DashboardAggregationConfig{Enabled: true}, }) cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} @@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { repo.mu.Lock() defer repo.mu.Unlock() require.Len(t, repo.markSucceeded, 1) + require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond) } func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) { diff --git a/backend/internal/service/usage_log_create_result.go b/backend/internal/service/usage_log_create_result.go new file mode 100644 index 00000000..5e18b44c --- /dev/null +++ b/backend/internal/service/usage_log_create_result.go @@ -0,0 +1,60 @@ +package service + +import "errors" + +type usageLogCreateDisposition int + +const ( + usageLogCreateDispositionUnknown usageLogCreateDisposition = iota + usageLogCreateDispositionNotPersisted +) + +type UsageLogCreateError struct { + err error + disposition usageLogCreateDisposition +} + +func (e *UsageLogCreateError) Error() string { + if e == nil || e.err == nil { + return "usage log create error" + } + return e.err.Error() +} + +func (e *UsageLogCreateError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func MarkUsageLogCreateNotPersisted(err error) error { + if err == nil { + return nil + } + return &UsageLogCreateError{ + err: err, + disposition: usageLogCreateDispositionNotPersisted, + } +} + +func IsUsageLogCreateNotPersisted(err error) bool { + if err == nil { + return false + } + var target *UsageLogCreateError + if !errors.As(err, &target) { + return false + } + return target.disposition == usageLogCreateDispositionNotPersisted +} + +func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool { + if inserted { + return true + } + if err == nil { + return false + } + return !IsUsageLogCreateNotPersisted(err) +} diff --git a/backend/migrations/071_add_usage_billing_dedup.sql 
b/backend/migrations/071_add_usage_billing_dedup.sql new file mode 100644 index 00000000..acc28459 --- /dev/null +++ b/backend/migrations/071_add_usage_billing_dedup.sql @@ -0,0 +1,13 @@ +-- 窄表账务幂等键:将“是否已扣费”从 usage_logs 解耦出来 +-- 幂等执行:可重复运行 + +CREATE TABLE IF NOT EXISTS usage_billing_dedup ( + id BIGSERIAL PRIMARY KEY, + request_id VARCHAR(255) NOT NULL, + api_key_id BIGINT NOT NULL, + request_fingerprint VARCHAR(64) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key + ON usage_billing_dedup (request_id, api_key_id); diff --git a/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql b/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql new file mode 100644 index 00000000..965a3412 --- /dev/null +++ b/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql @@ -0,0 +1,7 @@ +-- usage_billing_dedup 是按时间追加写入的幂等窄表。 +-- 使用 BRIN 支撑按 created_at 的批量保留期清理,尽量降低写放大。 +-- 使用 CONCURRENTLY 避免在热表上长时间阻塞写入。 + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin + ON usage_billing_dedup + USING BRIN (created_at); diff --git a/backend/migrations/073_add_usage_billing_dedup_archive.sql b/backend/migrations/073_add_usage_billing_dedup_archive.sql new file mode 100644 index 00000000..d156d4eb --- /dev/null +++ b/backend/migrations/073_add_usage_billing_dedup_archive.sql @@ -0,0 +1,10 @@ +-- 冷归档旧账务幂等键,缩小热表索引与清理范围,同时不丢失长期去重能力。 + +CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive ( + request_id VARCHAR(255) NOT NULL, + api_key_id BIGINT NOT NULL, + request_fingerprint VARCHAR(64) NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (request_id, api_key_id) +); diff --git a/deploy/build_image.sh b/deploy/build_image.sh old mode 100755 new mode 100644 diff --git a/deploy/install-datamanagementd.sh b/deploy/install-datamanagementd.sh old mode 100755 new mode 100644 From 
addefe79e13a47bcecba575dac27f1d22b6c761d Mon Sep 17 00:00:00 2001 From: ius Date: Thu, 12 Mar 2026 17:03:21 +0800 Subject: [PATCH 03/18] fix: align docker health checks with runtime image --- deploy/Dockerfile | 2 +- deploy/docker-compose.local.yml | 2 +- deploy/docker-compose.standalone.yml | 2 +- deploy/docker-compose.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deploy/Dockerfile b/deploy/Dockerfile index ffe815e5..0f4f1de9 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -105,7 +105,7 @@ EXPOSE 8080 # Health check HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ - CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 # Run the application ENTRYPOINT ["/app/sub2api"] diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index 0ef397df..d404ac0b 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -154,7 +154,7 @@ services: networks: - sub2api-network healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml index 7676fb97..df0ccfcc 100644 --- a/deploy/docker-compose.standalone.yml +++ b/deploy/docker-compose.standalone.yml @@ -94,7 +94,7 @@ services: - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index e5c97bf8..acd21fd9 100644 --- 
a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -146,7 +146,7 @@ services: networks: - sub2api-network healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s timeout: 10s retries: 3 From 8d4d3b03bb22597cbcebef37cf078dc032caf593 Mon Sep 17 00:00:00 2001 From: ius Date: Thu, 12 Mar 2026 17:08:57 +0800 Subject: [PATCH 04/18] fix: remove unused gateway usage helpers --- backend/internal/service/gateway_service.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5e23c9b1..f40119f7 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -107,22 +107,6 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() } -func claudeUsageHasAnyTokens(usage *ClaudeUsage) bool { - return usage != nil && (usage.InputTokens > 0 || - usage.OutputTokens > 0 || - usage.CacheCreationInputTokens > 0 || - usage.CacheReadInputTokens > 0 || - usage.CacheCreation5mTokens > 0 || - usage.CacheCreation1hTokens > 0) -} - -func openAIUsageHasAnyTokens(usage *OpenAIUsage) bool { - return usage != nil && (usage.InputTokens > 0 || - usage.OutputTokens > 0 || - usage.CacheCreationInputTokens > 0 || - usage.CacheReadInputTokens > 0) -} - func openAIStreamEventIsTerminal(data string) bool { trimmed := strings.TrimSpace(data) if trimmed == "" { From cdb64b0d337b147ff7218eab6b069fe5ea19f83b Mon Sep 17 00:00:00 2001 From: kunish Date: Thu, 12 Mar 2026 17:10:01 +0800 Subject: [PATCH 05/18] fix: remove SSE termination marker from DefaultStopSequences The SSE stream termination marker string was incorrectly included in DefaultStopSequences, causing Gemini to prematurely stop generating 
output whenever the model produced text containing that marker. The SSE-level protocol filtering in stream_transformer.go already handles this marker correctly; it should not be a stop sequence for the model's text generation. --- backend/internal/pkg/antigravity/gemini_types.go | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 0ff24a1f..1a0ca5bb 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -189,6 +189,5 @@ var DefaultStopSequences = []string{ "<|user|>", "<|endoftext|>", "<|end_of_turn|>", - "[DONE]", "\n\nHuman:", } From e97fd7e81c068e6120f1dbc0cd379acb5b1cd10b Mon Sep 17 00:00:00 2001 From: ius Date: Thu, 12 Mar 2026 17:22:01 +0800 Subject: [PATCH 06/18] test: align oauth passthrough stream expectations --- .../service/openai_oauth_passthrough_test.go | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 6fbd2469..f51a7491 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -439,7 +439,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) headers := make(http.Header) headers.Set("Content-Type", "application/json") @@ -453,7 +453,14 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes resp := &http.Response{ StatusCode: http.StatusOK, Header: 
headers, - Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} @@ -895,7 +902,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t * } _, err := svc.Forward(context.Background(), c, account, originalBody) - require.NoError(t, err) + require.EqualError(t, err, "stream usage incomplete: missing terminal event") require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 时结束,疑似断流")) require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info")) require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate")) @@ -911,11 +918,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t c.Request.Header.Set("x-stainless-timeout", "120000") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}}, - Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-default"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: 
{"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} svc := &OpenAIGatewayService{ @@ -952,11 +964,16 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured c.Request.Header.Set("x-stainless-timeout", "120000") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}}, - Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-filter-allow"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), } upstream := &httpUpstreamRecorder{resp: resp} svc := &OpenAIGatewayService{ From 18ba8d91669510db7d9a3e7463a07c1c9718c3b6 Mon Sep 17 00:00:00 2001 From: ius Date: Thu, 12 Mar 2026 17:42:41 +0800 Subject: [PATCH 07/18] fix: stabilize repository integration paths --- .../repository/fixtures_integration_test.go | 36 ++++++++++++++ backend/internal/repository/usage_log_repo.go | 47 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/backend/internal/repository/fixtures_integration_test.go b/backend/internal/repository/fixtures_integration_test.go index 23adb4e4..80b9cab6 100644 --- a/backend/internal/repository/fixtures_integration_test.go +++ 
b/backend/internal/repository/fixtures_integration_test.go @@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se SetKey(k.Key). SetName(k.Name). SetStatus(k.Status) + if k.Quota != 0 { + create.SetQuota(k.Quota) + } + if k.QuotaUsed != 0 { + create.SetQuotaUsed(k.QuotaUsed) + } + if k.RateLimit5h != 0 { + create.SetRateLimit5h(k.RateLimit5h) + } + if k.RateLimit1d != 0 { + create.SetRateLimit1d(k.RateLimit1d) + } + if k.RateLimit7d != 0 { + create.SetRateLimit7d(k.RateLimit7d) + } + if k.Usage5h != 0 { + create.SetUsage5h(k.Usage5h) + } + if k.Usage1d != 0 { + create.SetUsage1d(k.Usage1d) + } + if k.Usage7d != 0 { + create.SetUsage7d(k.Usage7d) + } + if k.Window5hStart != nil { + create.SetWindow5hStart(*k.Window5hStart) + } + if k.Window1dStart != nil { + create.SetWindow1dStart(*k.Window1dStart) + } + if k.Window7dStart != nil { + create.SetWindow7dStart(*k.Window7dStart) + } + if k.ExpiresAt != nil { + create.SetExpiresAt(*k.ExpiresAt) + } if k.GroupID != nil { create.SetGroupID(*k.GroupID) } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 5e81818b..53ca7d11 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -30,6 +30,45 @@ import ( const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" +var usageLogInsertArgTypes = [...]string{ + "bigint", + "bigint", + "bigint", + "text", + 
"text", + "bigint", + "bigint", + "integer", + "integer", + "integer", + "integer", + "integer", + "integer", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "numeric", + "smallint", + "smallint", + "boolean", + "boolean", + "integer", + "integer", + "text", + "text", + "integer", + "text", + "text", + "text", + "text", + "boolean", + "timestamptz", +} + // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ "hour": "YYYY-MM-DD HH24:00", @@ -713,6 +752,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage _, _ = query.WriteString(",") _, _ = query.WriteString("$") _, _ = query.WriteString(strconv.Itoa(argPos)) + if i < len(usageLogInsertArgTypes) { + _, _ = query.WriteString("::") + _, _ = query.WriteString(usageLogInsertArgTypes[i]) + } argPos++ } _, _ = query.WriteString(")") @@ -877,6 +920,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( } _, _ = query.WriteString("$") _, _ = query.WriteString(strconv.Itoa(argPos)) + if i < len(usageLogInsertArgTypes) { + _, _ = query.WriteString("::") + _, _ = query.WriteString(usageLogInsertArgTypes[i]) + } argPos++ } _, _ = query.WriteString(")") From 69cafe8674d94a8763cdd5595f9b44f50534a388 Mon Sep 17 00:00:00 2001 From: wanXcode Date: Thu, 12 Mar 2026 17:42:41 +0800 Subject: [PATCH 08/18] fix(dashboard): prefer username in user usage trend --- .../pkg/usagestats/usage_log_types.go | 1 + backend/internal/repository/usage_log_repo.go | 5 ++-- frontend/src/types/index.ts | 1 + frontend/src/views/admin/DashboardView.vue | 24 ++++++++++++------- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 8826c048..04414ebb 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ 
b/backend/internal/pkg/usagestats/usage_log_types.go @@ -96,6 +96,7 @@ type UserUsageTrendPoint struct { Date string `json:"date"` UserID int64 `json:"user_id"` Email string `json:"email"` + Username string `json:"username"` Requests int64 `json:"requests"` Tokens int64 `json:"tokens"` Cost float64 `json:"cost"` // 标准计费 diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index c91a68e5..c7adec44 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1114,6 +1114,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e TO_CHAR(u.created_at, '%s') as date, u.user_id, COALESCE(us.email, '') as email, + COALESCE(us.username, '') as username, COUNT(*) as requests, COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens, COALESCE(SUM(u.total_cost), 0) as cost, @@ -1122,7 +1123,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e LEFT JOIN users us ON u.user_id = us.id WHERE u.user_id IN (SELECT user_id FROM top_users) AND u.created_at >= $4 AND u.created_at < $5 - GROUP BY date, u.user_id, us.email + GROUP BY date, u.user_id, us.email, us.username ORDER BY date ASC, tokens DESC `, dateFormat) @@ -1142,7 +1143,7 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e results = make([]UserUsageTrendPoint, 0) for rows.Next() { var row UserUsageTrendPoint - if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { + if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Username, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { return nil, err } results = append(results, row) diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 5764134d..2181a011 100644 --- a/frontend/src/types/index.ts +++ 
b/frontend/src/types/index.ts @@ -1155,6 +1155,7 @@ export interface UserUsageTrendPoint { date: string user_id: number email: string + username: string requests: number tokens: number cost: number // 标准计费 diff --git a/frontend/src/views/admin/DashboardView.vue b/frontend/src/views/admin/DashboardView.vue index f86b54c5..7986dea4 100644 --- a/frontend/src/views/admin/DashboardView.vue +++ b/frontend/src/views/admin/DashboardView.vue @@ -415,23 +415,29 @@ const lineOptions = computed(() => ({ const userTrendChartData = computed(() => { if (!userTrend.value?.length) return null - // Extract display name from email (part before @) - const getDisplayName = (email: string, userId: number): string => { - if (email && email.includes('@')) { - return email.split('@')[0] + const getDisplayName = (point: UserUsageTrendPoint): string => { + const username = point.username?.trim() + if (username) { + return username } - return t('admin.redeem.userPrefix', { id: userId }) + + const email = point.email?.trim() + if (email) { + return email + } + + return t('admin.redeem.userPrefix', { id: point.user_id }) } - // Group by user - const userGroups = new Map }>() + // Group by user_id to avoid merging different users with the same display name + const userGroups = new Map }>() const allDates = new Set() userTrend.value.forEach((point) => { allDates.add(point.date) - const key = getDisplayName(point.email, point.user_id) + const key = point.user_id if (!userGroups.has(key)) { - userGroups.set(key, { name: key, data: new Map() }) + userGroups.set(key, { name: getDisplayName(point), data: new Map() }) } userGroups.get(key)!.data.set(point.date, point.tokens) }) From 32d25f76fc0b91a15c6ddf730e7b138373e59b7f Mon Sep 17 00:00:00 2001 From: ius Date: Thu, 12 Mar 2026 17:44:57 +0800 Subject: [PATCH 09/18] fix: respect preconfigured usage log batch channels --- backend/internal/repository/usage_log_repo.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 53ca7d11..5a022665 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -377,7 +377,7 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa } func (r *usageLogRepository) ensureCreateBatcher() { - if r == nil || r.db == nil { + if r == nil || r.db == nil || r.createBatchCh != nil { return } r.createBatchOnce.Do(func() { @@ -387,7 +387,7 @@ func (r *usageLogRepository) ensureCreateBatcher() { } func (r *usageLogRepository) ensureBestEffortBatcher() { - if r == nil || r.db == nil { + if r == nil || r.db == nil || r.bestEffortBatchCh != nil { return } r.bestEffortBatchOnce.Do(func() { From 6a685727d0504d07b21e2a763eff59ba9767e12d Mon Sep 17 00:00:00 2001 From: ius Date: Thu, 12 Mar 2026 18:38:09 +0800 Subject: [PATCH 10/18] fix: harden usage billing idempotency and backpressure --- backend/internal/repository/usage_log_repo.go | 46 ++++++---- .../usage_log_repo_integration_test.go | 57 +++++++++++++ .../repository/usage_log_repo_unit_test.go | 26 ++++++ .../service/gateway_record_usage_test.go | 85 +++++++++++++++++++ backend/internal/service/gateway_service.go | 11 ++- .../openai_gateway_record_usage_test.go | 85 +++++++++++++++++++ .../service/usage_log_create_result.go | 22 +++++ 7 files changed, 311 insertions(+), 21 deletions(-) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 5a022665..aab66081 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -246,16 +246,16 @@ func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service. 
select { case r.bestEffortBatchCh <- req: case <-ctx.Done(): - return ctx.Err() + return service.MarkUsageLogCreateDropped(ctx.Err()) default: - return errors.New("usage log best-effort queue full") + return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")) } select { case err := <-req.resultCh: return err case <-ctx.Done(): - return ctx.Err() + return service.MarkUsageLogCreateDropped(ctx.Err()) } } @@ -355,7 +355,7 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa case <-ctx.Done(): return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) default: - return r.createSingle(ctx, r.sql, log) + return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full")) } select { @@ -840,27 +840,39 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage cache_ttl_overridden, created_at FROM input - ON CONFLICT (request_id, api_key_id) DO UPDATE - SET request_id = usage_logs.request_id - RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING request_id, api_key_id, id, created_at + ), + resolved AS ( + SELECT + input.input_idx, + input.request_id, + input.api_key_id, + COALESCE(inserted.id, existing.id) AS id, + COALESCE(inserted.created_at, existing.created_at) AS created_at, + (inserted.id IS NOT NULL) AS inserted + FROM input + LEFT JOIN inserted + ON inserted.request_id = input.request_id + AND inserted.api_key_id = input.api_key_id + LEFT JOIN usage_logs existing + ON existing.request_id = input.request_id + AND existing.api_key_id = input.api_key_id ) SELECT COALESCE( json_agg( json_build_object( - 'request_id', inserted.request_id, - 'api_key_id', inserted.api_key_id, - 'id', inserted.id, - 'created_at', inserted.created_at, - 'inserted', inserted.inserted + 'request_id', resolved.request_id, + 'api_key_id', resolved.api_key_id, + 'id', resolved.id, + 
'created_at', resolved.created_at, + 'inserted', resolved.inserted ) - ORDER BY input.input_idx + ORDER BY resolved.input_idx ), '[]'::json ) - FROM input - JOIN inserted - ON inserted.request_id = input.request_id - AND inserted.api_key_id = input.api_key_id + FROM resolved `) return query.String(), args } diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 00740878..0383f3bc 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -288,6 +288,34 @@ func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testi }, 3*time.Second, 20*time.Millisecond) } +func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1) + repo.bestEffortBatchCh <- usageLogBestEffortRequest{} + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()}) + + err := repo.CreateBestEffort(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.Error(t, err) + require.True(t, service.IsUsageLogCreateDropped(err)) +} + func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) { client := testEntClient(t) repo := 
newUsageLogRepositoryWithSQL(client, integrationDB) @@ -317,6 +345,35 @@ func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *t require.True(t, service.IsUsageLogCreateNotPersisted(err)) } +func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.createBatchCh = make(chan usageLogCreateRequest, 1) + repo.createBatchCh <- usageLogCreateRequest{} + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()}) + + inserted, err := repo.Create(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.False(t, inserted) + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) +} + func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) { client := testEntClient(t) repo := newUsageLogRepositoryWithSQL(client, integrationDB) diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go index d0e14ffd..0458902d 100644 --- a/backend/internal/repository/usage_log_repo_unit_test.go +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -3,8 +3,11 @@ package repository import ( + "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" ) @@ -39,3 +42,26 @@ func TestSafeDateFormat(t 
*testing.T) { }) } } + +func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) { + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-batch-no-update", + Model: "gpt-5", + InputTokens: 10, + OutputTokens: 5, + TotalCost: 1.2, + ActualCost: 1.2, + CreatedAt: time.Now().UTC(), + } + prepared := prepareUsageLogInsert(log) + + query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{ + usageLogBatchKey(log.RequestID, log.APIKeyID): prepared, + }) + + require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING") + require.NotContains(t, strings.ToUpper(query), "DO UPDATE") +} diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 92e59ac8..475dea6f 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -4,6 +4,8 @@ package service import ( "context" + "errors" + "strings" "testing" "time" @@ -233,6 +235,89 @@ func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID) } +func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123") + ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "upstream-volatile-456", + Usage: ClaudeUsage{ + InputTokens: 10, + 
OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 506}, + User: &User{ID: 606}, + Account: &Account{ID: 706}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 507}, + User: &User{ID: 607}, + Account: &Account{ID: 707}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:")) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) { + usageRepo := &openAIRecordUsageBestEffortLogRepoStub{ + bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")), + } + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: 
&ForwardResult{ + RequestID: "gateway_drop_usage_log", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 508}, + User: &User{ID: 608}, + Account: &Account{ID: 708}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.bestEffortCalls) + require.Equal(t, 0, usageRepo.createCalls) +} + func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{} billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index f40119f7..a87255b0 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -6745,9 +6745,6 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { - if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { - return requestID - } if ctx != nil { if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { return "client:" + strings.TrimSpace(clientRequestID) @@ -6756,7 +6753,10 @@ func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) return "local:" + strings.TrimSpace(requestID) } } - return "" + if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { + return requestID + } + return "generated:" + generateRequestID() } func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string { @@ -6931,6 +6931,9 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage if writer, ok := repo.(usageLogBestEffortWriter); ok { if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { logger.LegacyPrintf(logKey, "Create usage log failed: 
%v", err) + if IsUsageLogCreateDropped(err) { + return + } if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index f05fa5f5..438e9aeb 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "strings" "testing" "time" @@ -28,6 +29,31 @@ func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog return s.inserted, s.err } +type openAIRecordUsageBestEffortLogRepoStub struct { + UsageLogRepository + + bestEffortErr error + createErr error + bestEffortCalls int + createCalls int + lastLog *UsageLog + lastCtxErr error +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error { + s.bestEffortCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return s.bestEffortErr +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.createCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return false, s.createErr +} + type openAIRecordUsageBillingRepoStub struct { UsageBillingRepository @@ -543,6 +569,65 @@ func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsage require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID) } +func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := 
newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123") + err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "upstream-openai-volatile-456", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10049}, + User: &User{ID: 20049}, + Account: &Account{ID: 30049}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10050}, + User: &User{ID: 20050}, + Account: &Account{ID: 30050}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:")) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID) +} + func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { usageRepo := 
&openAIRecordUsageLogRepoStub{} billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")} diff --git a/backend/internal/service/usage_log_create_result.go b/backend/internal/service/usage_log_create_result.go index 5e18b44c..1cd84f44 100644 --- a/backend/internal/service/usage_log_create_result.go +++ b/backend/internal/service/usage_log_create_result.go @@ -7,6 +7,7 @@ type usageLogCreateDisposition int const ( usageLogCreateDispositionUnknown usageLogCreateDisposition = iota usageLogCreateDispositionNotPersisted + usageLogCreateDispositionDropped ) type UsageLogCreateError struct { @@ -38,6 +39,16 @@ func MarkUsageLogCreateNotPersisted(err error) error { } } +func MarkUsageLogCreateDropped(err error) error { + if err == nil { + return nil + } + return &UsageLogCreateError{ + err: err, + disposition: usageLogCreateDispositionDropped, + } +} + func IsUsageLogCreateNotPersisted(err error) bool { if err == nil { return false @@ -49,6 +60,17 @@ func IsUsageLogCreateNotPersisted(err error) bool { return target.disposition == usageLogCreateDispositionNotPersisted } +func IsUsageLogCreateDropped(err error) bool { + if err == nil { + return false + } + var target *UsageLogCreateError + if !errors.As(err, &target) { + return false + } + return target.disposition == usageLogCreateDispositionDropped +} + func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool { if inserted { return true From 64b3f3cec183a9334a93aa476fc7a2b96abb4308 Mon Sep 17 00:00:00 2001 From: ius Date: Thu, 12 Mar 2026 18:43:37 +0800 Subject: [PATCH 11/18] test: relocate best-effort usage log stub --- .../service/gateway_record_usage_test.go | 25 +++++++++++++++++++ .../openai_gateway_record_usage_test.go | 25 ------------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 475dea6f..4e7e545a 100644 --- 
a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -49,6 +49,31 @@ func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogReposi return svc } +type openAIRecordUsageBestEffortLogRepoStub struct { + UsageLogRepository + + bestEffortErr error + createErr error + bestEffortCalls int + createCalls int + lastLog *UsageLog + lastCtxErr error +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error { + s.bestEffortCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return s.bestEffortErr +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.createCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return false, s.createErr +} + func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 438e9aeb..cd4d58fd 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -29,31 +29,6 @@ func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog return s.inserted, s.err } -type openAIRecordUsageBestEffortLogRepoStub struct { - UsageLogRepository - - bestEffortErr error - createErr error - bestEffortCalls int - createCalls int - lastLog *UsageLog - lastCtxErr error -} - -func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error { - s.bestEffortCalls++ - s.lastLog = log - s.lastCtxErr = ctx.Err() - return s.bestEffortErr -} - -func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, 
error) { - s.createCalls++ - s.lastLog = log - s.lastCtxErr = ctx.Err() - return false, s.createErr -} - type openAIRecordUsageBillingRepoStub struct { UsageBillingRepository From f16910d6167bb9fb82a37eb735af2a89a4fd48ee Mon Sep 17 00:00:00 2001 From: yexueduxing Date: Thu, 12 Mar 2026 20:52:35 +0800 Subject: [PATCH 12/18] chore: codex transform fixes and feature compatibility --- .../service/openai_codex_transform.go | 72 ++++++++++++++++++- .../service/openai_codex_transform_test.go | 4 +- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index b0e4d44f..8fffce1b 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -129,6 +129,41 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact } } + // 兼容遗留的 functions 和 function_call,转换为 tools 和 tool_choice + if functionsRaw, ok := reqBody["functions"]; ok { + if functions, k := functionsRaw.([]any); k { + tools := make([]any, 0, len(functions)) + for _, f := range functions { + tools = append(tools, map[string]any{ + "type": "function", + "function": f, + }) + } + reqBody["tools"] = tools + } + delete(reqBody, "functions") + result.Modified = true + } + + if fcRaw, ok := reqBody["function_call"]; ok { + if fcStr, ok := fcRaw.(string); ok { + // e.g. "auto", "none" + reqBody["tool_choice"] = fcStr + } else if fcObj, ok := fcRaw.(map[string]any); ok { + // e.g. 
{"name": "my_func"} + if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" { + reqBody["tool_choice"] = map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + }, + } + } + } + delete(reqBody, "function_call") + result.Modified = true + } + if normalizeCodexTools(reqBody) { result.Modified = true } @@ -303,6 +338,18 @@ func filterCodexInput(input []any, preserveReferences bool) []any { continue } typ, _ := m["type"].(string) + + // 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'" + fixIDPrefix := func(id string) string { + if id == "" || strings.HasPrefix(id, "fc") { + return id + } + if strings.HasPrefix(id, "call_") { + return "fc" + strings.TrimPrefix(id, "call_") + } + return "fc_" + id + } + if typ == "item_reference" { if !preserveReferences { continue @@ -311,6 +358,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any { for key, value := range m { newItem[key] = value } + if id, ok := newItem["id"].(string); ok && id != "" { + newItem["id"] = fixIDPrefix(id) + } filtered = append(filtered, newItem) continue } @@ -330,10 +380,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } if isCodexToolCallItemType(typ) { - if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" { + callID, ok := m["call_id"].(string) + if !ok || strings.TrimSpace(callID) == "" { if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" { + callID = id ensureCopy() - newItem["call_id"] = id + newItem["call_id"] = callID + } + } + + if callID != "" { + fixedCallID := fixIDPrefix(callID) + if fixedCallID != callID { + ensureCopy() + newItem["call_id"] = fixedCallID } } } @@ -344,6 +404,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any { if !isCodexToolCallItemType(typ) { delete(newItem, "call_id") } + } else { + if id, ok := newItem["id"].(string); ok && id != "" { + fixedID := fixIDPrefix(id) + if fixedID != id { + ensureCopy() + 
newItem["id"] = fixedID + } + } } filtered = append(filtered, newItem) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index c8097aed..df012d7c 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -33,12 +33,12 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { first, ok := input[0].(map[string]any) require.True(t, ok) require.Equal(t, "item_reference", first["type"]) - require.Equal(t, "ref1", first["id"]) + require.Equal(t, "fc_ref1", first["id"]) // 校验 input[1] 为 map,确保后续字段断言安全。 second, ok := input[1].(map[string]any) require.True(t, ok) - require.Equal(t, "o1", second["id"]) + require.Equal(t, "fc_o1", second["id"]) } func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { From 80d8d6c3bc70595c6f4cd43e0355fb6c865d73c2 Mon Sep 17 00:00:00 2001 From: Peter <1tRq4X287b7W7sfKf9GsWI+Peter@noreply.cnb.cool> Date: Fri, 13 Mar 2026 03:41:29 +0800 Subject: [PATCH 13/18] feat(admin): add user spending ranking dashboard view --- .../handler/admin/dashboard_handler.go | 51 ++++ .../dashboard_handler_request_type_test.go | 43 ++++ .../handler/sora_gateway_handler_test.go | 3 + .../pkg/usagestats/usage_log_types.go | 15 ++ backend/internal/repository/usage_log_repo.go | 76 ++++++ .../usage_log_repo_request_type_test.go | 29 +++ backend/internal/server/api_contract_test.go | 4 + backend/internal/server/routes/admin.go | 1 + .../internal/service/account_usage_service.go | 1 + backend/internal/service/dashboard_service.go | 8 + frontend/src/api/admin/dashboard.ts | 21 ++ .../charts/ModelDistributionChart.vue | 223 +++++++++++++++--- frontend/src/i18n/locales/en.ts | 12 + frontend/src/i18n/locales/zh.ts | 12 + frontend/src/types/index.ts | 15 ++ frontend/src/views/admin/DashboardView.vue | 80 ++++++- frontend/src/views/admin/UsageView.vue | 36 ++- 17 files changed, 
591 insertions(+), 39 deletions(-) diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index aa82b24f..cc4ef2d0 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct { UserIDs []int64 `json:"user_ids" binding:"required"` } +var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second) var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second) +func parseRankingLimit(raw string) int { + limit, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || limit <= 0 { + return 12 + } + if limit > 50 { + return 50 + } + return limit +} + +// GetUserSpendingRanking handles getting user spending ranking data. +// GET /api/v1/admin/dashboard/users-ranking +func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + limit := parseRankingLimit(c.DefaultQuery("limit", "12")) + + keyRaw, _ := json.Marshal(struct { + Start string `json:"start"` + End string `json:"end"` + Limit int `json:"limit"` + }{ + Start: startTime.UTC().Format(time.RFC3339), + End: endTime.UTC().Format(time.RFC3339), + Limit: limit, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit) + if err != nil { + response.Error(c, 500, "Failed to get user spending ranking") + return + } + + payload := gin.H{ + "ranking": ranking.Ranking, + "total_actual_cost": ranking.TotalActualCost, + "start_date": startTime.Format("2006-01-02"), + "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), + } + dashboardUsersRankingCache.Set(cacheKey, 
payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) +} + // GetBatchUsersUsage handles getting usage stats for multiple users // POST /api/v1/admin/dashboard/users-usage func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go index 72af6b45..6b363bb5 100644 --- a/backend/internal/handler/admin/dashboard_handler_request_type_test.go +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct { trendStream *bool modelRequestType *int16 modelStream *bool + rankingLimit int + ranking []usagestats.UserSpendingRankingItem + rankingTotal float64 } func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters( @@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters( return []usagestats.ModelStat{}, nil } +func (s *dashboardUsageRepoCapture) GetUserSpendingRanking( + ctx context.Context, + startTime, endTime time.Time, + limit int, +) (*usagestats.UserSpendingRankingResponse, error) { + s.rankingLimit = limit + return &usagestats.UserSpendingRankingResponse{ + Ranking: s.ranking, + TotalActualCost: s.rankingTotal, + }, nil +} + func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine { gin.SetMode(gin.TestMode) dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) @@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng router := gin.New() router.GET("/admin/dashboard/trend", handler.GetUsageTrend) router.GET("/admin/dashboard/models", handler.GetModelStats) + router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking) return router } @@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } + +func 
TestDashboardUsersRankingLimitAndCache(t *testing.T) { + dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) + repo := &dashboardUsageRepoCapture{ + ranking: []usagestats.UserSpendingRankingItem{ + {UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300}, + }, + rankingTotal: 88.8, + } + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 50, repo.rankingLimit) + require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8") + require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) +} diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 688c5d12..d452b6cb 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -343,6 +343,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { return nil, nil } +func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + return nil, nil +} func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) 
(map[int64]*usagestats.BatchUserUsageStats, error) { return nil, nil } diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 8826c048..78ca9107 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -102,6 +102,21 @@ type UserUsageTrendPoint struct { ActualCost float64 `json:"actual_cost"` // 实际扣除 } +// UserSpendingRankingItem represents a user spending ranking row. +type UserSpendingRankingItem struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + ActualCost float64 `json:"actual_cost"` // 实际扣除 + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` +} + +// UserSpendingRankingResponse represents ranking rows plus total spend for the time range. +type UserSpendingRankingResponse struct { + Ranking []UserSpendingRankingItem `json:"ranking"` + TotalActualCost float64 `json:"total_actual_cost"` +} + // APIKeyUsageTrendPoint represents API key usage trend data point type APIKeyUsageTrendPoint struct { Date string `json:"date"` diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index c91a68e5..7cf23ac0 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1039,6 +1039,10 @@ type ModelStat = usagestats.ModelStat // UserUsageTrendPoint represents user usage trend data point type UserUsageTrendPoint = usagestats.UserUsageTrendPoint +// UserSpendingRankingItem represents a user spending ranking row. 
+type UserSpendingRankingItem = usagestats.UserSpendingRankingItem +type UserSpendingRankingResponse = usagestats.UserSpendingRankingResponse + // APIKeyUsageTrendPoint represents API key usage trend data point type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint @@ -1154,6 +1158,78 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e return results, nil } +// GetUserSpendingRanking returns user spending ranking aggregated within the time range. +func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (result *UserSpendingRankingResponse, err error) { + if limit <= 0 { + limit = 12 + } + + query := ` + WITH user_spend AS ( + SELECT + u.user_id, + COALESCE(us.email, '') as email, + COALESCE(SUM(u.actual_cost), 0) as actual_cost, + COUNT(*) as requests, + COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens + FROM usage_logs u + LEFT JOIN users us ON u.user_id = us.id + WHERE u.created_at >= $1 AND u.created_at < $2 + GROUP BY u.user_id, us.email + ), + ranked AS ( + SELECT + user_id, + email, + actual_cost, + requests, + tokens, + COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost + FROM user_spend + ORDER BY actual_cost DESC, tokens DESC, user_id ASC + LIMIT $3 + ) + SELECT + user_id, + email, + actual_cost, + requests, + tokens, + total_actual_cost + FROM ranked + ORDER BY actual_cost DESC, tokens DESC, user_id ASC + ` + + rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + result = nil + } + }() + + ranking := make([]UserSpendingRankingItem, 0) + totalActualCost := 0.0 + for rows.Next() { + var row UserSpendingRankingItem + if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost); err != nil { + 
return nil, err + } + ranking = append(ranking, row) + } + if err = rows.Err(); err != nil { + return nil, err + } + + return &UserSpendingRankingResponse{ + Ranking: ranking, + TotalActualCost: totalActualCost, + }, nil +} + // UserDashboardStats 用户仪表盘统计 type UserDashboardStats = usagestats.UserDashboardStats diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 7d82b4d0..bcb23717 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -248,6 +248,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) require.NoError(t, mock.ExpectationsWereMet()) } +func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + + rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}). + AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0). + AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0). + AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0) + + mock.ExpectQuery("WITH user_spend AS \\("). + WithArgs(start, end, 12). 
+ WillReturnRows(rows) + + got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12) + require.NoError(t, err) + require.Equal(t, &usagestats.UserSpendingRankingResponse{ + Ranking: []usagestats.UserSpendingRankingItem{ + {UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900}, + {UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800}, + {UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300}, + }, + TotalActualCost: 40.0, + }, got) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 0b36bf66..15c5506d 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1635,6 +1635,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { logs := r.userLogs[userID] if len(logs) == 0 { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 9fdb233b..4842be28 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -192,6 +192,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats) dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend) dashboard.GET("/users-trend", 
h.Admin.Dashboard.GetUserUsageTrend) + dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage) dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation) diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index e4245133..3dd931be 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -47,6 +47,7 @@ type UsageLogRepository interface { GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) + GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 2af43386..63cad243 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end return trend, nil } +func (s *DashboardService) GetUserSpendingRanking(ctx 
context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) { + ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit) + if err != nil { + return nil, fmt.Errorf("get user spending ranking: %w", err) + } + return ranking, nil +} + func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 4393dda3..85200506 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -11,6 +11,7 @@ import type { GroupStat, ApiKeyUsageTrendPoint, UserUsageTrendPoint, + UserSpendingRankingResponse, UsageRequestType } from '@/types' @@ -201,6 +202,11 @@ export interface UserTrendResponse { granularity: string } +export interface UserSpendingRankingParams + extends Pick { + limit?: number +} + /** * Get user usage trend data * @param params - Query parameters for filtering @@ -213,6 +219,20 @@ export async function getUserUsageTrend(params?: UserTrendParams): Promise { + const { data } = await apiClient.get('/admin/dashboard/users-ranking', { + params + }) + return data +} + export interface BatchUserUsageStats { user_id: number today_actual_cost: number @@ -271,6 +291,7 @@ export const dashboardAPI = { getSnapshotV2, getApiKeyUsageTrend, getUserUsageTrend, + getUserSpendingRanking, getBatchUsersUsage, getBatchApiKeysUsage } diff --git a/frontend/src/components/charts/ModelDistributionChart.vue b/frontend/src/components/charts/ModelDistributionChart.vue index 6f80e541..5db5a14f 100644 --- a/frontend/src/components/charts/ModelDistributionChart.vue +++ b/frontend/src/components/charts/ModelDistributionChart.vue @@ -2,38 +2,72 @@

- {{ t('admin.dashboard.modelDistribution') }} + {{ !enableRankingView || activeView === 'model_distribution' + ? t('admin.dashboard.modelDistribution') + : t('admin.dashboard.spendingRankingTitle') }}

-
- - + + +
+
+ + +
-
+ +
-
+
@@ -77,6 +111,70 @@
+
+ {{ t('admin.dashboard.noDataAvailable') }} +
+ +
+ +
+
+ {{ t('admin.dashboard.failedToLoad') }} +
+
+
+ +
+
+ + + + + + + + + + + + + + + + + +
{{ t('admin.dashboard.spendingRankingUser') }}{{ t('admin.dashboard.spendingRankingRequests') }}{{ t('admin.dashboard.spendingRankingTokens') }}{{ t('admin.dashboard.spendingRankingSpend') }}
+
+ + #{{ index + 1 }} + + + {{ getRankingUserLabel(item) }} + +
+
+ {{ formatNumber(item.requests) }} + + {{ formatTokens(item.tokens) }} + + ${{ formatCost(item.actual_cost) }} +
+
+