feat: decouple billing correctness from usage log batching

This commit is contained in:
ius
2026-03-12 16:53:18 +08:00
parent c9debc50b1
commit 611fd884bd
37 changed files with 3379 additions and 330 deletions

View File

@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
sql sqlExecutor
}
// usageLogsCleanupBatchSize caps how many usage_logs rows each DELETE batch
// removes, so retention cleanup avoids one huge long-locking statement.
const usageLogsCleanupBatchSize = 10000
// usageBillingDedupCleanupBatchSize caps each archive-and-delete batch for
// usage_billing_dedup retention cleanup.
const usageBillingDedupCleanupBatchSize = 10000
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
dayEnd = dayEnd.Add(24 * time.Hour)
}
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid
FROM usage_logs
WHERE created_at < $1
LIMIT $2
)
DELETE FROM usage_logs
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageLogsCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageLogsCleanupBatchSize {
return nil
}
}
}
// CleanupUsageBillingDedup removes usage_billing_dedup rows older than cutoff
// in fixed-size batches, archiving each victim row into
// usage_billing_dedup_archive (ON CONFLICT DO NOTHING) before deleting it so
// billing idempotency survives retention cleanup.
//
// The loop ends once a batch deletes fewer rows than the batch size, meaning
// no further rows match the cutoff.
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
	// Guard a zero-value repository, mirroring AggregateRange: a nil executor
	// previously caused a nil-pointer panic here.
	if r == nil || r.sql == nil {
		return nil
	}
	for {
		res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
FROM usage_billing_dedup
WHERE created_at < $1
LIMIT $2
), archived AS (
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
SELECT request_id, api_key_id, request_fingerprint, created_at
FROM victims
ON CONFLICT (request_id, api_key_id) DO NOTHING
)
DELETE FROM usage_billing_dedup
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
		if err != nil {
			return err
		}
		affected, err := res.RowsAffected()
		if err != nil {
			return err
		}
		// A short batch means the table has no more expired rows.
		if affected < usageBillingDedupCleanupBatchSize {
			return nil
		}
	}
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {

View File

@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// usage_billing_dedup: billing idempotency narrow table
var usageBillingDedupRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
var usageBillingDedupArchiveRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
// requireIndex fails the test unless the named index exists on the given
// table in the public schema.
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
	t.Helper()
	const existsQuery = `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`
	var found bool
	queryErr := tx.QueryRowContext(context.Background(), existsQuery, table, index).Scan(&found)
	require.NoError(t, queryErr, "query pg_indexes for %s.%s", table, index)
	require.True(t, found, "expected index %s on %s", index, table)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()

View File

@@ -0,0 +1,308 @@
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// usageBillingRepository applies billing side effects for a usage event
// against Postgres via raw SQL, all inside a single transaction (see Apply).
type usageBillingRepository struct {
	db *sql.DB // raw pool; billing writes bypass the ent client
}
// NewUsageBillingRepository constructs the SQL-backed usage billing
// repository. The ent client argument is intentionally unused; it is kept for
// wiring symmetry with other repository constructors.
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
	return &usageBillingRepository{db: sqlDB}
}
// Apply executes a usage billing command exactly once per
// (request_id, api_key_id) key. It claims the idempotency key and applies all
// billing effects in one transaction; a duplicate request returns
// Applied=false without re-applying effects.
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
	// A nil command is a no-op, not an error.
	if cmd == nil {
		return &service.UsageBillingApplyResult{}, nil
	}
	if r == nil || r.db == nil {
		return nil, errors.New("usage billing repository db is nil")
	}
	cmd.Normalize()
	if cmd.RequestID == "" {
		return nil, service.ErrUsageBillingRequestIDRequired
	}
	tx, beginErr := r.db.BeginTx(ctx, nil)
	if beginErr != nil {
		return nil, beginErr
	}
	// Roll back on every path that does not reach a successful commit.
	committed := false
	defer func() {
		if !committed {
			_ = tx.Rollback()
		}
	}()
	claimed, claimErr := r.claimUsageBillingKey(ctx, tx, cmd)
	if claimErr != nil {
		return nil, claimErr
	}
	if !claimed {
		// Duplicate request: effects were applied by an earlier call.
		return &service.UsageBillingApplyResult{Applied: false}, nil
	}
	out := &service.UsageBillingApplyResult{Applied: true}
	if effectErr := r.applyUsageBillingEffects(ctx, tx, cmd, out); effectErr != nil {
		return nil, effectErr
	}
	if commitErr := tx.Commit(); commitErr != nil {
		return nil, commitErr
	}
	committed = true
	return out, nil
}
// claimUsageBillingKey tries to claim the (request_id, api_key_id)
// idempotency key inside tx. It returns (true, nil) when this call won the
// claim and billing effects should be applied; (false, nil) when the key was
// already claimed — in the live dedup table or in the archive — with the same
// request fingerprint; and ErrUsageBillingRequestConflict when a prior claim
// carries a different fingerprint (same request id reused for different work).
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
	var id int64
	// ON CONFLICT DO NOTHING + RETURNING yields sql.ErrNoRows when the key
	// already exists, signalling a duplicate request.
	err := tx.QueryRowContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
VALUES ($1, $2, $3)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
	if errors.Is(err, sql.ErrNoRows) {
		// Key exists in the live table: compare fingerprints to distinguish a
		// safe retry from a conflicting reuse of the request id.
		var existingFingerprint string
		if err := tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
			return false, err
		}
		if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
			return false, service.ErrUsageBillingRequestConflict
		}
		return false, nil
	}
	if err != nil {
		return false, err
	}
	// The insert succeeded, but the key may have been archived by retention
	// cleanup. Check the archive; if found there, the surrounding transaction
	// is rolled back by the caller, undoing the insert above.
	var archivedFingerprint string
	err = tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup_archive
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
	if err == nil {
		if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
			return false, service.ErrUsageBillingRequestConflict
		}
		return false, nil
	}
	if !errors.Is(err, sql.ErrNoRows) {
		return false, err
	}
	// No archive entry: this call owns the claim.
	return true, nil
}
// applyUsageBillingEffects applies each non-zero billing cost in cmd inside
// tx: subscription usage, user balance, API key quota, API key rate-limit
// windows, and (for api-key accounts) account quota. The quota-exhausted flag
// is recorded on result.
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
	hasSubscriptionCharge := cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil
	if hasSubscriptionCharge {
		if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
			return err
		}
	}
	if cmd.BalanceCost > 0 {
		if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
			return err
		}
	}
	if cmd.APIKeyQuotaCost > 0 {
		exhausted, quotaErr := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
		if quotaErr != nil {
			return quotaErr
		}
		result.APIKeyQuotaExhausted = exhausted
	}
	if cmd.APIKeyRateLimitCost > 0 {
		if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
			return err
		}
	}
	// Account-level quota only applies to api-key style accounts.
	isAPIKeyAccount := strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey)
	if cmd.AccountQuotaCost > 0 && isAPIKeyAccount {
		if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
			return err
		}
	}
	return nil
}
// incrementUsageBillingSubscription adds costUSD to the daily, weekly and
// monthly usage counters of a live subscription whose group is also live.
// It returns ErrSubscriptionNotFound when no row matched.
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
	const updateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
	res, execErr := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
	if execErr != nil {
		return execErr
	}
	affected, countErr := res.RowsAffected()
	if countErr != nil {
		return countErr
	}
	// Zero rows means the subscription (or its group) is missing or deleted.
	if affected == 0 {
		return service.ErrSubscriptionNotFound
	}
	return nil
}
// deductUsageBillingBalance subtracts amount from a live user's balance.
// It returns ErrUserNotFound when the user does not exist or is soft-deleted.
// NOTE(review): the balance is allowed to go negative here — presumably the
// caller enforces any floor; confirm against the service layer.
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
	res, execErr := tx.ExecContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, amount, userID)
	if execErr != nil {
		return execErr
	}
	affected, countErr := res.RowsAffected()
	if countErr != nil {
		return countErr
	}
	if affected == 0 {
		return service.ErrUserNotFound
	}
	return nil
}
// incrementUsageBillingAPIKeyQuota adds amount to an API key's quota_used and
// flips an active key to the quota-exhausted status when this increment
// crosses the quota boundary. The returned bool reports whether this call was
// the one that crossed the boundary (so the caller can react exactly once).
// Returns ErrAPIKeyNotFound when the key is missing or soft-deleted.
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
	var exhausted bool
	// In the SET clause, column references read OLD values, so the CASE
	// compares pre-update quota_used against quota. In RETURNING, columns are
	// NEW values, hence "quota_used - $1" recovers the old value to detect a
	// crossing: new >= quota AND old < quota.
	err := tx.QueryRowContext(ctx, `
UPDATE api_keys
SET quota_used = quota_used + $1,
status = CASE
WHEN quota > 0
AND status = $3
AND quota_used < quota
AND quota_used + $1 >= quota
THEN $4
ELSE status
END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
	if errors.Is(err, sql.ErrNoRows) {
		return false, service.ErrAPIKeyNotFound
	}
	if err != nil {
		return false, err
	}
	return exhausted, nil
}
// incrementUsageBillingAPIKeyRateLimit adds cost to an API key's rolling
// 5-hour, 1-day and 7-day usage windows. An expired or unset window is reset:
// its usage restarts at cost and its start timestamp is re-anchored (NOW() for
// the 5h window, start-of-day for the 1d/7d windows). Otherwise cost is added
// to the current window. Returns ErrAPIKeyNotFound when no live key matched.
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
	res, err := tx.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, cost, apiKeyID)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return service.ErrAPIKeyNotFound
	}
	return nil
}
// incrementUsageBillingAccountQuota adds amount to an account's quota
// counters stored in the jsonb "extra" column: the lifetime quota_used always,
// plus daily/weekly rolling counters when the corresponding limit is
// configured (>0). Expired daily (24h) / weekly (168h) windows restart at
// amount with a fresh start timestamp. When this increment crosses the total
// quota_limit, a scheduler-outbox account-changed event is enqueued so the
// scheduler can react to the newly exhausted account.
// Returns ErrAccountNotFound when no live account matched.
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
	// nowUTC is a SQL expression constant defined elsewhere in this package
	// (concatenated into the query text, not a bind parameter).
	rows, err := tx.QueryContext(ctx,
		`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
		amount, accountID)
	if err != nil {
		return err
	}
	defer func() { _ = rows.Close() }()
	var newUsed, limit float64
	if rows.Next() {
		if err := rows.Scan(&newUsed, &limit); err != nil {
			return err
		}
	} else {
		if err := rows.Err(); err != nil {
			return err
		}
		// No row returned and no iteration error: the account is absent.
		return service.ErrAccountNotFound
	}
	if err := rows.Err(); err != nil {
		return err
	}
	// Fire the outbox event only on the call that crosses the limit
	// (newUsed-amount recovers the pre-update value).
	if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
		if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
			logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
			return err
		}
	}
	return nil
}

View File

@@ -0,0 +1,279 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling verifies that a
// repeated Apply with the same request id/API key deducts balance, quota and
// rate-limit usage exactly once, and that the dedup table holds one row.
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)
	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
		Balance:      100,
	})
	// Quota of 1 guarantees the 1.25 charge exhausts the key on first apply.
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID: user.ID,
		Key:    "sk-usage-billing-" + uuid.NewString(),
		Name:   "billing",
		Quota:  1,
	})
	account := mustCreateAccount(t, client, &service.Account{
		Name: "usage-billing-account-" + uuid.NewString(),
		Type: service.AccountTypeAPIKey,
	})
	requestID := uuid.NewString()
	cmd := &service.UsageBillingCommand{
		RequestID:           requestID,
		APIKeyID:            apiKey.ID,
		UserID:              user.ID,
		AccountID:           account.ID,
		AccountType:         service.AccountTypeAPIKey,
		BalanceCost:         1.25,
		APIKeyQuotaCost:     1.25,
		APIKeyRateLimitCost: 1.25,
	}
	// First apply performs the billing effects.
	result1, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.NotNil(t, result1)
	require.True(t, result1.Applied)
	require.True(t, result1.APIKeyQuotaExhausted)
	// Second apply with the identical command must be a deduplicated no-op.
	result2, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.NotNil(t, result2)
	require.False(t, result2.Applied)
	// All counters reflect exactly one application of the 1.25 cost.
	var balance float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
	require.InDelta(t, 98.75, balance, 0.000001)
	var quotaUsed float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed))
	require.InDelta(t, 1.25, quotaUsed, 0.000001)
	var usage5h float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
	require.InDelta(t, 1.25, usage5h, 0.000001)
	var status string
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
	require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
	var dedupCount int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
	require.Equal(t, 1, dedupCount)
}
// TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling verifies
// that subscription usage is incremented exactly once for a repeated command.
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)
	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
	})
	group := mustCreateGroup(t, client, &service.Group{
		Name:             "usage-billing-group-" + uuid.NewString(),
		Platform:         service.PlatformAnthropic,
		SubscriptionType: service.SubscriptionTypeSubscription,
	})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID:  user.ID,
		GroupID: &group.ID,
		Key:     "sk-usage-billing-sub-" + uuid.NewString(),
		Name:    "billing-sub",
	})
	subscription := mustCreateSubscription(t, client, &service.UserSubscription{
		UserID:  user.ID,
		GroupID: group.ID,
	})
	requestID := uuid.NewString()
	cmd := &service.UsageBillingCommand{
		RequestID:        requestID,
		APIKeyID:         apiKey.ID,
		UserID:           user.ID,
		AccountID:        0,
		SubscriptionID:   &subscription.ID,
		SubscriptionCost: 2.5,
	}
	result1, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.True(t, result1.Applied)
	// Duplicate apply must not double-charge the subscription.
	result2, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.False(t, result2.Applied)
	var dailyUsage float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
	require.InDelta(t, 2.5, dailyUsage, 0.000001)
}
// TestUsageBillingRepositoryApply_RequestFingerprintConflict verifies that
// reusing a request id with a different command payload (hence a different
// fingerprint) is rejected with ErrUsageBillingRequestConflict.
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)
	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
		Balance:      100,
	})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID: user.ID,
		Key:    "sk-usage-billing-conflict-" + uuid.NewString(),
		Name:   "billing-conflict",
	})
	requestID := uuid.NewString()
	_, err := repo.Apply(ctx, &service.UsageBillingCommand{
		RequestID:   requestID,
		APIKeyID:    apiKey.ID,
		UserID:      user.ID,
		BalanceCost: 1.25,
	})
	require.NoError(t, err)
	// Same request id, different cost: fingerprint differs, so this must fail.
	_, err = repo.Apply(ctx, &service.UsageBillingCommand{
		RequestID:   requestID,
		APIKeyID:    apiKey.ID,
		UserID:      user.ID,
		BalanceCost: 2.50,
	})
	require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
}
// TestUsageBillingRepositoryApply_UpdatesAccountQuota verifies that an
// AccountQuotaCost on an api-key account increments the jsonb quota_used.
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)
	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
	})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID: user.ID,
		Key:    "sk-usage-billing-account-" + uuid.NewString(),
		Name:   "billing-account",
	})
	account := mustCreateAccount(t, client, &service.Account{
		Name: "usage-billing-account-quota-" + uuid.NewString(),
		Type: service.AccountTypeAPIKey,
		Extra: map[string]any{
			"quota_limit": 100.0,
		},
	})
	_, err := repo.Apply(ctx, &service.UsageBillingCommand{
		RequestID:        uuid.NewString(),
		APIKeyID:         apiKey.ID,
		UserID:           user.ID,
		AccountID:        account.ID,
		AccountType:      service.AccountTypeAPIKey,
		AccountQuotaCost: 3.5,
	})
	require.NoError(t, err)
	// quota_used lives inside the accounts.extra jsonb column.
	var quotaUsed float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed))
	require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
// TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows
// verifies that cleanup deletes rows older than the cutoff, keeps newer rows,
// and copies deleted rows into the archive table.
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
	ctx := context.Background()
	repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
	oldRequestID := "dedup-old-" + uuid.NewString()
	newRequestID := "dedup-new-" + uuid.NewString()
	// One row ~400 days old (past the 365-day cutoff) and one recent row.
	oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
	newCreatedAt := time.Now().UTC().Add(-time.Hour)
	_, err := integrationDB.ExecContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
`,
		oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
		newRequestID, strings.Repeat("b", 64), newCreatedAt,
	)
	require.NoError(t, err)
	require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
	// Expired row deleted from the live table...
	var oldCount int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
	require.Equal(t, 0, oldCount)
	// ...recent row retained...
	var newCount int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
	require.Equal(t, 1, newCount)
	// ...and the expired row archived for future idempotency checks.
	var archivedCount int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
	require.Equal(t, 1, archivedCount)
}
// TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey verifies that
// a request whose dedup row was aged out into the archive is still treated as
// a duplicate: re-applying after cleanup must not deduct the balance again.
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)
	aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
		Balance:      100,
	})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID: user.ID,
		Key:    "sk-usage-billing-archive-" + uuid.NewString(),
		Name:   "billing-archive",
	})
	requestID := uuid.NewString()
	cmd := &service.UsageBillingCommand{
		RequestID:   requestID,
		APIKeyID:    apiKey.ID,
		UserID:      user.ID,
		BalanceCost: 1.25,
	}
	result1, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.True(t, result1.Applied)
	// Backdate the dedup row past the retention cutoff, then run cleanup so
	// the key moves from the live table into the archive.
	_, err = integrationDB.ExecContext(ctx, `
UPDATE usage_billing_dedup
SET created_at = $1
WHERE request_id = $2 AND api_key_id = $3
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
	require.NoError(t, err)
	require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
	// Replay: must be rejected via the archive lookup.
	result2, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.False(t, result2.Applied)
	var balance float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
	require.InDelta(t, 98.75, balance, 0.000001)
}

View File

@@ -3,12 +3,14 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -17,11 +19,13 @@ import (
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
@@ -47,18 +51,29 @@ type usageLogRepository struct {
sql sqlExecutor
db *sql.DB
createBatchOnce sync.Once
createBatchCh chan usageLogCreateRequest
createBatchOnce sync.Once
createBatchCh chan usageLogCreateRequest
bestEffortBatchOnce sync.Once
bestEffortBatchCh chan usageLogBestEffortRequest
bestEffortRecent *gocache.Cache
}
const (
	// Synchronous create batching: small batch and a 3ms window keep request
	// latency low while still coalescing bursts.
	usageLogCreateBatchMaxSize  = 64
	usageLogCreateBatchWindow   = 3 * time.Millisecond
	usageLogCreateBatchQueueCap = 4096
	// How long a canceled caller keeps waiting for an in-flight batch result
	// before giving up with the context error.
	usageLogCreateCancelWait = 2 * time.Second
	// Best-effort batching trades latency for throughput: larger batches,
	// longer window, much deeper queue.
	usageLogBestEffortBatchMaxSize  = 256
	usageLogBestEffortBatchWindow   = 20 * time.Millisecond
	usageLogBestEffortBatchQueueCap = 32768
	// TTL of the recent-request cache used to drop duplicate best-effort writes.
	usageLogBestEffortRecentTTL = 30 * time.Second
)
// usageLogCreateRequest is one synchronous create queued for the batcher:
// the original log, its pre-computed insert values, shared cancellation state
// between caller and batcher, and a buffered channel for the result.
type usageLogCreateRequest struct {
	log      *service.UsageLog
	prepared usageLogInsertPrepared
	shared   *usageLogCreateShared
	resultCh chan usageLogCreateResult
}
@@ -67,6 +82,12 @@ type usageLogCreateResult struct {
err error
}
// usageLogBestEffortRequest is one best-effort insert queued for the
// background batcher; resultCh delivers the flush outcome to the caller.
type usageLogBestEffortRequest struct {
	prepared usageLogInsertPrepared
	apiKeyID int64 // paired with prepared.requestID for the recent-dedup key
	resultCh chan error
}
type usageLogInsertPrepared struct {
createdAt time.Time
requestID string
@@ -80,6 +101,25 @@ type usageLogBatchState struct {
CreatedAt time.Time
}
// usageLogBatchRow is the per-row outcome of a batched insert, keyed by
// (request_id, api_key_id); Inserted reports whether this call created the row.
type usageLogBatchRow struct {
	RequestID string    `json:"request_id"`
	APIKeyID  int64     `json:"api_key_id"`
	ID        int64     `json:"id"`
	CreatedAt time.Time `json:"created_at"`
	Inserted  bool      `json:"inserted"`
}
// usageLogCreateShared carries the lifecycle state of a queued create,
// shared between the enqueueing caller and the batcher goroutine so a
// canceled caller and the batcher can race safely via CAS.
type usageLogCreateShared struct {
	state atomic.Int32
}
// Lifecycle states for a queued create request. Transitions happen via
// CompareAndSwap: queued -> processing (batcher claims it) or
// queued -> canceled (caller gave up before the batcher started).
const (
	usageLogCreateStateQueued int32 = iota
	usageLogCreateStateProcessing
	usageLogCreateStateCompleted
	usageLogCreateStateCanceled
)
// NewUsageLogRepository constructs the usage log repository backed by the
// given ent client and raw SQL handle.
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
	return newUsageLogRepositoryWithSQL(client, sqlDB)
}
@@ -90,6 +130,7 @@ func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usage
if db, ok := sqlq.(*sql.DB); ok {
repo.db = db
}
repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute)
return repo
}
@@ -124,9 +165,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
if tx := dbent.TxFromContext(ctx); tx != nil {
return r.createSingle(ctx, tx.Client(), log)
}
if r.db == nil {
return r.createSingle(ctx, r.sql, log)
}
requestID := strings.TrimSpace(log.RequestID)
if requestID == "" {
return r.createSingle(ctx, r.sql, log)
@@ -135,11 +173,61 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
return r.createBatched(ctx, log)
}
// CreateBestEffort persists a usage log with relaxed guarantees, preferring
// the background best-effort batcher over per-row inserts. Inside an existing
// ent transaction, or without a raw *sql.DB, it falls back to a single insert.
// Recently seen (request_id, api_key_id) pairs are dropped via the
// bestEffortRecent cache to avoid duplicate writes on retries.
func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error {
	if log == nil {
		return nil
	}
	// Honor an ambient transaction so the write shares its atomicity.
	if tx := dbent.TxFromContext(ctx); tx != nil {
		_, err := r.createSingle(ctx, tx.Client(), log)
		return err
	}
	if r.db == nil {
		_, err := r.createSingle(ctx, r.sql, log)
		return err
	}
	r.ensureBestEffortBatcher()
	if r.bestEffortBatchCh == nil {
		_, err := r.createSingle(ctx, r.sql, log)
		return err
	}
	req := usageLogBestEffortRequest{
		prepared: prepareUsageLogInsert(log),
		apiKeyID: log.APIKeyID,
		resultCh: make(chan error, 1),
	}
	// Drop duplicates seen within the recent-TTL window.
	if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok {
		if _, exists := r.bestEffortRecent.Get(key); exists {
			return nil
		}
	}
	// Non-blocking enqueue: with the default case present this select never
	// waits, so a full queue fails fast instead of stalling the request path.
	select {
	case r.bestEffortBatchCh <- req:
	case <-ctx.Done():
		return ctx.Err()
	default:
		return errors.New("usage log best-effort queue full")
	}
	// Wait for the flush outcome (or context cancellation).
	select {
	case err := <-req.resultCh:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
prepared := prepareUsageLogInsert(log)
if sqlq == nil {
sqlq = r.sql
}
if ctx != nil && ctx.Err() != nil {
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
}
query := `
INSERT INTO usage_logs (
@@ -218,13 +306,15 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
req := usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
shared: &usageLogCreateShared{},
resultCh: make(chan usageLogCreateResult, 1),
}
select {
case r.createBatchCh <- req:
case <-ctx.Done():
return false, ctx.Err()
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
default:
return r.createSingle(ctx, r.sql, log)
}
@@ -233,7 +323,17 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
case res := <-req.resultCh:
return res.inserted, res.err
case <-ctx.Done():
return false, ctx.Err()
if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) {
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
}
timer := time.NewTimer(usageLogCreateCancelWait)
defer timer.Stop()
select {
case res := <-req.resultCh:
return res.inserted, res.err
case <-timer.C:
return false, ctx.Err()
}
}
}
@@ -247,6 +347,16 @@ func (r *usageLogRepository) ensureCreateBatcher() {
})
}
// ensureBestEffortBatcher lazily starts the background best-effort batching
// goroutine exactly once; it is a no-op without a raw *sql.DB handle.
func (r *usageLogRepository) ensureBestEffortBatcher() {
	if r == nil {
		return
	}
	db := r.db
	if db == nil {
		return
	}
	r.bestEffortBatchOnce.Do(func() {
		queue := make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap)
		r.bestEffortBatchCh = queue
		go r.runBestEffortBatcher(db)
	})
}
func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
for {
first, ok := <-r.createBatchCh
@@ -281,6 +391,40 @@ func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
}
}
// runBestEffortBatcher drains the best-effort queue, grouping requests into
// batches bounded by usageLogBestEffortBatchMaxSize or by a short collection
// window, and flushes each batch. It returns when the queue channel closes.
func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) {
	for {
		head, open := <-r.bestEffortBatchCh
		if !open {
			return
		}
		pending := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize)
		pending = append(pending, head)
		window := time.NewTimer(usageLogBestEffortBatchWindow)
	collect:
		for len(pending) < usageLogBestEffortBatchMaxSize {
			select {
			case next, more := <-r.bestEffortBatchCh:
				if !more {
					break collect
				}
				pending = append(pending, next)
			case <-window.C:
				break collect
			}
		}
		// Stop the window timer and drain it if it already fired, so the
		// next iteration never sees a stale tick.
		if !window.Stop() {
			select {
			case <-window.C:
			default:
			}
		}
		r.flushBestEffortBatch(db, pending)
	}
}
func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
if len(batch) == 0 {
return
@@ -293,10 +437,19 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
for _, req := range batch {
if req.log == nil {
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil})
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
continue
}
prepared := prepareUsageLogInsert(req.log)
if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) {
if req.shared.state.Load() == usageLogCreateStateCanceled {
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: false,
err: service.MarkUsageLogCreateNotPersisted(context.Canceled),
})
continue
}
}
prepared := req.prepared
if prepared.requestID == "" {
fallback = append(fallback, req)
continue
@@ -310,10 +463,37 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
}
if len(uniqueOrder) > 0 {
insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
if err != nil {
for _, key := range uniqueOrder {
fallback = append(fallback, requestsByKey[key]...)
if safeFallback {
for _, key := range uniqueOrder {
fallback = append(fallback, requestsByKey[key]...)
}
} else {
for _, key := range uniqueOrder {
reqs := requestsByKey[key]
state, hasState := stateMap[key]
inserted := insertedMap[key]
for idx, req := range reqs {
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
if hasState {
req.log.ID = state.ID
req.log.CreatedAt = state.CreatedAt
}
switch {
case inserted && idx == 0:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil})
case inserted:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
case hasState:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
case idx == 0:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err})
default:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
}
}
}
}
} else {
for _, key := range uniqueOrder {
@@ -321,7 +501,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
state, ok := stateMap[key]
if !ok {
for _, req := range reqs {
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: false,
err: fmt.Errorf("usage log batch state missing for key=%s", key),
})
@@ -332,7 +512,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
req.log.ID = state.ID
req.log.CreatedAt = state.CreatedAt
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: idx == 0 && insertedMap[key],
err: nil,
})
@@ -345,56 +525,366 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
return
}
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for _, req := range fallback {
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
inserted, err := r.createSingle(fallbackCtx, db, req.log)
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err})
cancel()
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err})
}
}
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) {
// flushBestEffortBatch persists a batch of best-effort usage-log requests
// with a single multi-row insert. Requests sharing the same
// (request_id, api_key_id) dedup key are collapsed into one candidate row.
// If the batched insert fails, it degrades to one single-row insert per
// group so a single bad row cannot sink the whole batch. Every request is
// always answered (non-blocking), and successfully written request IDs are
// recorded in the bestEffortRecent cache.
func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) {
	if len(batch) == 0 {
		return
	}
	// bestEffortGroup collects every request that maps to one dedup key;
	// only the group's first prepared row is sent to the database.
	type bestEffortGroup struct {
		prepared usageLogInsertPrepared
		apiKeyID int64
		key      string
		reqs     []usageLogBestEffortRequest
	}
	groupsByKey := make(map[string]*bestEffortGroup, len(batch))
	groupOrder := make([]*bestEffortGroup, 0, len(batch))
	preparedList := make([]usageLogInsertPrepared, 0, len(batch))
	for idx, req := range batch {
		prepared := req.prepared
		// Requests without a request ID cannot be deduplicated: give each
		// a unique synthetic key so it forms its own group.
		key := fmt.Sprintf("__best_effort_%d", idx)
		if prepared.requestID != "" {
			key = usageLogBatchKey(prepared.requestID, req.apiKeyID)
		}
		group, exists := groupsByKey[key]
		if !exists {
			group = &bestEffortGroup{
				prepared: prepared,
				apiKeyID: req.apiKeyID,
				key:      key,
			}
			groupsByKey[key] = group
			groupOrder = append(groupOrder, group)
			preparedList = append(preparedList, prepared)
		}
		group.reqs = append(group.reqs, req)
	}
	if len(preparedList) == 0 {
		// Nothing to persist; acknowledge everyone with a nil error.
		for _, req := range batch {
			sendUsageLogBestEffortResult(req.resultCh, nil)
		}
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	query, args := buildUsageLogBestEffortInsertQuery(preparedList)
	if _, err := db.ExecContext(ctx, query, args...); err != nil {
		logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err)
		// Batched insert failed: retry each group individually so rows
		// unrelated to the failure still get persisted.
		for _, group := range groupOrder {
			singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared)
			if singleErr != nil {
				logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr)
			} else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
				// Remember the key only after a confirmed write.
				r.bestEffortRecent.SetDefault(group.key, struct{}{})
			}
			for _, req := range group.reqs {
				sendUsageLogBestEffortResult(req.resultCh, singleErr)
			}
		}
		return
	}
	// Batch succeeded: record dedup keys and acknowledge all requests.
	for _, group := range groupOrder {
		if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
			r.bestEffortRecent.SetDefault(group.key, struct{}{})
		}
		for _, req := range group.reqs {
			sendUsageLogBestEffortResult(req.resultCh, nil)
		}
	}
}
// sendUsageLogBestEffortResult delivers err on ch without ever blocking the
// caller. A nil channel is ignored, and when the channel cannot accept the
// value immediately the result is simply dropped — best-effort by design.
func sendUsageLogBestEffortResult(ch chan error, err error) {
	if ch != nil {
		select {
		case ch <- err:
		default:
		}
	}
}
// completeUsageLogCreateRequest finalizes a batched create request: it moves
// the shared state (when present) to its terminal "completed" value and then
// delivers the result to the caller without blocking.
func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) {
	if shared := req.shared; shared != nil {
		shared.state.Store(usageLogCreateStateCompleted)
	}
	sendUsageLogCreateResult(req.resultCh, res)
}
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) {
if len(keys) == 0 {
return map[string]bool{}, map[string]usageLogBatchState{}, nil
return map[string]bool{}, map[string]usageLogBatchState{}, false, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, nil, err
var payload []byte
if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil {
return nil, nil, true, err
}
var rows []usageLogBatchRow
if err := json.Unmarshal(payload, &rows); err != nil {
return nil, nil, false, err
}
insertedMap := make(map[string]bool, len(keys))
for rows.Next() {
var (
requestID string
apiKeyID int64
id int64
createdAt time.Time
)
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
_ = rows.Close()
return nil, nil, err
stateMap := make(map[string]usageLogBatchState, len(keys))
for _, row := range rows {
key := usageLogBatchKey(row.RequestID, row.APIKeyID)
insertedMap[key] = row.Inserted
stateMap[key] = usageLogBatchState{
ID: row.ID,
CreatedAt: row.CreatedAt,
}
insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true
}
if err := rows.Err(); err != nil {
_ = rows.Close()
return nil, nil, err
if len(stateMap) != len(keys) {
return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys))
}
_ = rows.Close()
stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey)
if err != nil {
return nil, nil, err
}
return insertedMap, stateMap, nil
return insertedMap, stateMap, false, nil
}
// buildUsageLogBatchInsertQuery builds one upsert statement covering every
// key in the batch. Each VALUES row is prefixed with an input_idx so the
// aggregated JSON result can be ordered deterministically. The
// ON CONFLICT ... DO UPDATE clause is a deliberate no-op self-assignment:
// unlike DO NOTHING it makes conflicting (pre-existing) rows visible to
// RETURNING, and (xmax = 0) distinguishes freshly inserted rows from
// conflicted ones. The final SELECT folds all returned rows into a single
// JSON array payload ('[]' when empty) so the caller can read one row.
func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
	var query strings.Builder
	_, _ = query.WriteString(`
WITH input (
input_idx,
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) AS (VALUES `)
	// 36 prepared args per row plus the injected input_idx value.
	args := make([]any, 0, len(keys)*37)
	argPos := 1
	for idx, key := range keys {
		if idx > 0 {
			_, _ = query.WriteString(",")
		}
		_, _ = query.WriteString("(")
		// First placeholder of every row is the ordering index.
		_, _ = query.WriteString("$")
		_, _ = query.WriteString(strconv.Itoa(argPos))
		args = append(args, idx)
		argPos++
		prepared := preparedByKey[key]
		for i := 0; i < len(prepared.args); i++ {
			_, _ = query.WriteString(",")
			_, _ = query.WriteString("$")
			_, _ = query.WriteString(strconv.Itoa(argPos))
			argPos++
		}
		_, _ = query.WriteString(")")
		args = append(args, prepared.args...)
	}
	_, _ = query.WriteString(`
),
inserted AS (
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
)
SELECT
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO UPDATE
SET request_id = usage_logs.request_id
RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted
)
SELECT COALESCE(
json_agg(
json_build_object(
'request_id', inserted.request_id,
'api_key_id', inserted.api_key_id,
'id', inserted.id,
'created_at', inserted.created_at,
'inserted', inserted.inserted
)
ORDER BY input.input_idx
),
'[]'::json
)
FROM input
JOIN inserted
ON inserted.request_id = input.request_id
AND inserted.api_key_id = input.api_key_id
`)
	return query.String(), args
}
func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) {
var query strings.Builder
_, _ = query.WriteString(`
WITH input (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*36)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("(")
for i := 0; i < len(prepared.args); i++ {
if i > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
argPos++
}
_, _ = query.WriteString(")")
args = append(args, prepared.args...)
}
_, _ = query.WriteString(`
)
INSERT INTO usage_logs (
user_id,
api_key_id,
@@ -432,80 +922,101 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES `)
args := make([]any, 0, len(keys)*36)
argPos := 1
for idx, key := range keys {
if idx > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("(")
prepared := preparedByKey[key]
for i := 0; i < len(prepared.args); i++ {
if i > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
argPos++
}
_, _ = query.WriteString(")")
args = append(args, prepared.args...)
}
_, _ = query.WriteString(`
)
SELECT
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING request_id, api_key_id, id, created_at
`)
return query.String(), args
}
func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) {
var query strings.Builder
_, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `)
args := make([]any, 0, len(keys)*2)
argPos := 1
for idx, key := range keys {
if idx > 0 {
_, _ = query.WriteString(" OR ")
}
prepared := preparedByKey[key]
apiKeyID := prepared.args[1]
_, _ = query.WriteString("(request_id = $")
_, _ = query.WriteString(strconv.Itoa(argPos))
_, _ = query.WriteString(" AND api_key_id = $")
_, _ = query.WriteString(strconv.Itoa(argPos + 1))
_, _ = query.WriteString(")")
args = append(args, prepared.requestID, apiKeyID)
argPos += 2
}
rows, err := db.QueryContext(ctx, query.String(), args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
stateMap := make(map[string]usageLogBatchState, len(keys))
for rows.Next() {
var (
requestID string
apiKeyID int64
id int64
createdAt time.Time
func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error {
_, err := sqlq.ExecContext(ctx, `
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
)
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
return nil, err
}
stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{
ID: id,
CreatedAt: createdAt,
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return stateMap, nil
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
return err
}
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
@@ -597,6 +1108,14 @@ func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateRe
}
}
// bestEffortRecentKey derives the dedup-cache key for a best-effort write.
// It reports false when the trimmed request ID is empty or when the recent
// cache is unavailable, in which case no dedup lookup should be attempted.
func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) {
	trimmed := strings.TrimSpace(requestID)
	if trimmed == "" {
		return "", false
	}
	if r == nil || r.bestEffortRecent == nil {
		return "", false
	}
	return usageLogBatchKey(trimmed, apiKeyID), true
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
rows, err := r.sql.QueryContext(ctx, query, id)

View File

@@ -183,6 +183,214 @@ func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
require.Equal(t, 1, count)
}
// TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory verifies
// that when several queued create requests share one (request_id, api_key_id)
// key, flushCreateBatch collapses them in memory: exactly one row is written,
// exactly one caller observes inserted=true, and every request's log is
// backfilled with the same persisted ID.
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)
	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
	requestID := uuid.NewString()
	const total = 8
	batch := make([]usageLogCreateRequest, 0, total)
	logs := make([]*service.UsageLog, 0, total)
	// All requests share requestID; token counts vary so the test would
	// notice if the wrong candidate row ever won.
	for i := 0; i < total; i++ {
		log := &service.UsageLog{
			UserID:       user.ID,
			APIKeyID:     apiKey.ID,
			AccountID:    account.ID,
			RequestID:    requestID,
			Model:        "claude-3",
			InputTokens:  10 + i,
			OutputTokens: 20 + i,
			TotalCost:    0.5,
			ActualCost:   0.5,
			CreatedAt:    time.Now().UTC(),
		}
		logs = append(logs, log)
		batch = append(batch, usageLogCreateRequest{
			log:      log,
			prepared: prepareUsageLogInsert(log),
			resultCh: make(chan usageLogCreateResult, 1),
		})
	}
	repo.flushCreateBatch(integrationDB, batch)
	insertedCount := 0
	var firstID int64
	for idx, req := range batch {
		res := <-req.resultCh
		require.NoError(t, res.err)
		if res.inserted {
			insertedCount++
		}
		// Every log must be backfilled with the single persisted row's ID.
		require.NotZero(t, logs[idx].ID)
		if idx == 0 {
			firstID = logs[idx].ID
		} else {
			require.Equal(t, firstID, logs[idx].ID)
		}
	}
	// Only one caller may be told its write performed the insert.
	require.Equal(t, 1, insertedCount)
	var count int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
	require.Equal(t, 1, count)
}
// TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID verifies
// that two best-effort writes carrying the same request ID for the same API
// key end up as a single persisted row.
//
// Improvement over the previous version: the two byte-identical 12-field
// struct literals are folded into one local constructor, so the "identical
// payload" invariant is enforced by construction instead of by copy-paste.
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)
	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
	requestID := uuid.NewString()
	// newLog builds a fresh usage log; both calls share requestID so the two
	// writes collide on the (request_id, api_key_id) dedup key.
	newLog := func() *service.UsageLog {
		return &service.UsageLog{
			UserID:       user.ID,
			APIKeyID:     apiKey.ID,
			AccountID:    account.ID,
			RequestID:    requestID,
			Model:        "claude-3",
			InputTokens:  10,
			OutputTokens: 20,
			TotalCost:    0.5,
			ActualCost:   0.5,
			CreatedAt:    time.Now().UTC(),
		}
	}
	require.NoError(t, repo.CreateBestEffort(ctx, newLog()))
	require.NoError(t, repo.CreateBestEffort(ctx, newLog()))
	// Best-effort writes land asynchronously; poll until exactly one row
	// for this key is visible.
	require.Eventually(t, func() bool {
		var count int
		err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
		return err == nil && count == 1
	}, 3*time.Second, 20*time.Millisecond)
}
// TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted
// checks that a Create call made with an already-canceled context reports
// inserted=false and a "not persisted" error instead of silently succeeding.
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)
	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})

	// Cancel before calling Create so the batch path sees a dead context.
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	entry := &service.UsageLog{
		UserID:       user.ID,
		APIKeyID:     apiKey.ID,
		AccountID:    account.ID,
		RequestID:    uuid.NewString(),
		Model:        "claude-3",
		InputTokens:  10,
		OutputTokens: 20,
		TotalCost:    0.5,
		ActualCost:   0.5,
		CreatedAt:    time.Now().UTC(),
	}
	inserted, err := repo.Create(canceledCtx, entry)
	require.False(t, inserted)
	require.Error(t, err)
	require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
// TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted
// covers the window where a request has already been queued to the batcher
// but the caller's context is canceled before the batch is flushed: the
// caller must receive a "not persisted" error rather than block or assume
// the write succeeded.
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)
	// Swap in a fresh channel so the test — not a background batcher
	// goroutine — receives the queued request and controls its fate.
	repo.createBatchCh = make(chan usageLogCreateRequest, 1)
	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
	ctx, cancel := context.WithCancel(context.Background())
	errCh := make(chan error, 1)
	go func() {
		_, err := repo.createBatched(ctx, &service.UsageLog{
			UserID:       user.ID,
			APIKeyID:     apiKey.ID,
			AccountID:    account.ID,
			RequestID:    uuid.NewString(),
			Model:        "claude-3",
			InputTokens:  10,
			OutputTokens: 20,
			TotalCost:    0.5,
			ActualCost:   0.5,
			CreatedAt:    time.Now().UTC(),
		})
		errCh <- err
	}()
	// Receiving from the channel proves the request is queued; only then is
	// the caller's context canceled.
	req := <-repo.createBatchCh
	require.NotNil(t, req.shared)
	cancel()
	err := <-errCh
	require.Error(t, err)
	require.True(t, service.IsUsageLogCreateNotPersisted(err))
	// Finish the intercepted request so its shared state reaches a terminal
	// value, mirroring what the real batcher would do for a canceled request.
	completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
}
// TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted
// checks that flushCreateBatch skips a request whose shared state was marked
// canceled before the flush, reporting it back as not persisted.
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)
	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})

	entry := &service.UsageLog{
		UserID:       user.ID,
		APIKeyID:     apiKey.ID,
		AccountID:    account.ID,
		RequestID:    uuid.NewString(),
		Model:        "claude-3",
		InputTokens:  10,
		OutputTokens: 20,
		TotalCost:    0.5,
		ActualCost:   0.5,
		CreatedAt:    time.Now().UTC(),
	}
	// Mark the shared state canceled up front, then hand the request to the
	// flusher as the only member of its batch.
	shared := &usageLogCreateShared{}
	shared.state.Store(usageLogCreateStateCanceled)
	canceledReq := usageLogCreateRequest{
		log:      entry,
		prepared: prepareUsageLogInsert(entry),
		shared:   shared,
		resultCh: make(chan usageLogCreateResult, 1),
	}

	repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{canceledReq})

	result := <-canceledReq.resultCh
	require.False(t, result.inserted)
	require.Error(t, result.err)
	require.True(t, service.IsUsageLogCreateNotPersisted(result.err))
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})

View File

@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewUsageBillingRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,