mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-27 01:44:48 +08:00
feat: decouple billing correctness from usage log batching
This commit is contained in:
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// usageLogsCleanupBatchSize caps the number of usage_logs rows deleted per
// statement during retention cleanup, keeping each DELETE short-lived.
const usageLogsCleanupBatchSize = 10000

// usageBillingDedupCleanupBatchSize caps the number of usage_billing_dedup
// rows archived and deleted per statement during retention cleanup.
const usageBillingDedupCleanupBatchSize = 10000
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
if sqlDB == nil {
|
||||
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid
|
||||
FROM usage_logs
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM usage_logs
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageLogsCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM usage_billing_dedup
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
), archived AS (
|
||||
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM victims
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
)
|
||||
DELETE FROM usage_billing_dedup
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageBillingDedupCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
|
||||
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
|
||||
// usage_billing_dedup: billing idempotency narrow table
|
||||
var usageBillingDedupRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||
|
||||
var usageBillingDedupArchiveRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||
}
|
||||
|
||||
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||
t.Helper()
|
||||
|
||||
var exists bool
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = $1
|
||||
AND indexname = $2
|
||||
)
|
||||
`, table, index).Scan(&exists)
|
||||
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||
require.True(t, exists, "expected index %s on %s", index, table)
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
308
backend/internal/repository/usage_billing_repo.go
Normal file
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// usageBillingRepository applies billing side effects transactionally against
// the raw SQL database, using the usage_billing_dedup table (plus its archive)
// for per-request idempotency.
type usageBillingRepository struct {
	db *sql.DB
}
|
||||
|
||||
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||||
return &usageBillingRepository{db: sqlDB}
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||||
if cmd == nil {
|
||||
return &service.UsageBillingApplyResult{}, nil
|
||||
}
|
||||
if r == nil || r.db == nil {
|
||||
return nil, errors.New("usage billing repository db is nil")
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
if cmd.RequestID == "" {
|
||||
return nil, service.ErrUsageBillingRequestIDRequired
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !applied {
|
||||
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||||
}
|
||||
|
||||
result := &service.UsageBillingApplyResult{Applied: true}
|
||||
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx = nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// claimUsageBillingKey attempts to claim the (request_id, api_key_id)
// idempotency key inside tx. It returns (true, nil) when this call won the
// claim and billing effects should be applied; (false, nil) when the key was
// already claimed with the same fingerprint (duplicate request); and
// ErrUsageBillingRequestConflict when the key exists — live or archived — with
// a different fingerprint.
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
	var id int64
	// Try to insert the dedup row. ON CONFLICT DO NOTHING means the RETURNING
	// clause yields no row (sql.ErrNoRows) when the key is already claimed.
	err := tx.QueryRowContext(ctx, `
		INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
		VALUES ($1, $2, $3)
		ON CONFLICT (request_id, api_key_id) DO NOTHING
		RETURNING id
	`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
	if errors.Is(err, sql.ErrNoRows) {
		// Key already present in the live table: confirm it refers to the
		// same logical request before reporting a benign duplicate.
		var existingFingerprint string
		if err := tx.QueryRowContext(ctx, `
			SELECT request_fingerprint
			FROM usage_billing_dedup
			WHERE request_id = $1 AND api_key_id = $2
		`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
			return false, err
		}
		if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
			return false, service.ErrUsageBillingRequestConflict
		}
		return false, nil
	}
	if err != nil {
		return false, err
	}
	// The insert succeeded, but the key may have been claimed long ago and
	// since moved to the archive table by retention cleanup — check there too.
	// When we return false here, the caller (Apply) does not commit, so the
	// fresh live-table insert above is rolled back.
	var archivedFingerprint string
	err = tx.QueryRowContext(ctx, `
		SELECT request_fingerprint
		FROM usage_billing_dedup_archive
		WHERE request_id = $1 AND api_key_id = $2
	`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
	if err == nil {
		if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
			return false, service.ErrUsageBillingRequestConflict
		}
		return false, nil
	}
	if !errors.Is(err, sql.ErrNoRows) {
		return false, err
	}
	// No live or archived claim: this call owns the key.
	return true, nil
}
|
||||
|
||||
// applyUsageBillingEffects applies each cost component of cmd inside tx, in a
// fixed order: subscription usage, user balance, API-key quota, API-key
// rate-limit windows, then account quota. Any failure aborts and is expected
// to roll back the whole transaction in the caller.
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
	// Subscription usage counters, only when a subscription is attached.
	if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
		if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
			return err
		}
	}

	// Direct balance deduction from the owning user.
	if cmd.BalanceCost > 0 {
		if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
			return err
		}
	}

	if cmd.APIKeyQuotaCost > 0 {
		exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
		if err != nil {
			return err
		}
		// Surface to the caller whether this increment crossed the key's quota.
		result.APIKeyQuotaExhausted = exhausted
	}

	if cmd.APIKeyRateLimitCost > 0 {
		if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
			return err
		}
	}

	// Account-level quota applies only to api-key type accounts.
	if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
		if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
			return err
		}
	}

	return nil
}
|
||||
|
||||
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||||
const updateSQL = `
|
||||
UPDATE user_subscriptions us
|
||||
SET
|
||||
daily_usage_usd = us.daily_usage_usd + $1,
|
||||
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||||
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||||
updated_at = NOW()
|
||||
FROM groups g
|
||||
WHERE us.id = $2
|
||||
AND us.deleted_at IS NULL
|
||||
AND us.group_id = g.id
|
||||
AND g.deleted_at IS NULL
|
||||
`
|
||||
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE users
|
||||
SET balance = balance - $1,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, amount, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
// incrementUsageBillingAPIKeyQuota adds amount to the key's quota_used and, in
// the same statement, flips an active key into the quota-exhausted status when
// this increment crosses the configured quota. The returned bool reports
// whether this particular increment crossed the threshold (quota configured,
// previous usage below it, new usage at or above it).
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
	var exhausted bool
	// NOTE: column references in SET's CASE see the pre-update row, while the
	// RETURNING clause sees the post-update row — hence "quota_used - $1"
	// there reconstructs the pre-update usage (standard PostgreSQL semantics).
	err := tx.QueryRowContext(ctx, `
		UPDATE api_keys
		SET quota_used = quota_used + $1,
			status = CASE
				WHEN quota > 0
					AND status = $3
					AND quota_used < quota
					AND quota_used + $1 >= quota
				THEN $4
				ELSE status
			END,
			updated_at = NOW()
		WHERE id = $2 AND deleted_at IS NULL
		RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
	`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
	if errors.Is(err, sql.ErrNoRows) {
		// No matching row: key does not exist or is soft-deleted.
		return false, service.ErrAPIKeyNotFound
	}
	if err != nil {
		return false, err
	}
	return exhausted, nil
}
|
||||
|
||||
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, cost, apiKeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementUsageBillingAccountQuota adds amount to the JSONB usage counters in
// accounts.extra: the lifetime quota_used unconditionally, plus rolling daily
// and weekly counters when the corresponding *_limit keys are configured
// (> 0). A rolling window resets (usage restarts at amount, start re-anchored
// to `+nowUTC+`) once its start timestamp is at least 24h / 168h old. When this
// increment pushes lifetime usage across quota_limit, a scheduler outbox
// event is enqueued so the account can be re-evaluated.
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
	// One UPDATE merges freshly built JSONB objects over the existing extra
	// column; RETURNING yields the post-update lifetime usage and the limit.
	rows, err := tx.QueryContext(ctx,
		`UPDATE accounts SET extra = (
			COALESCE(extra, '{}'::jsonb)
			|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
			|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
				jsonb_build_object(
					'quota_daily_used',
					CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
						+ '24 hours'::interval <= NOW()
						THEN $1
						ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
					'quota_daily_start',
					CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
						+ '24 hours'::interval <= NOW()
						THEN `+nowUTC+`
						ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
				)
			ELSE '{}'::jsonb END
			|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
				jsonb_build_object(
					'quota_weekly_used',
					CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
						+ '168 hours'::interval <= NOW()
						THEN $1
						ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
					'quota_weekly_start',
					CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
						+ '168 hours'::interval <= NOW()
						THEN `+nowUTC+`
						ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
				)
			ELSE '{}'::jsonb END
		), updated_at = NOW()
		WHERE id = $2 AND deleted_at IS NULL
		RETURNING
			COALESCE((extra->>'quota_used')::numeric, 0),
			COALESCE((extra->>'quota_limit')::numeric, 0)`,
		amount, accountID)
	if err != nil {
		return err
	}
	defer func() { _ = rows.Close() }()

	var newUsed, limit float64
	if rows.Next() {
		if err := rows.Scan(&newUsed, &limit); err != nil {
			return err
		}
	} else {
		if err := rows.Err(); err != nil {
			return err
		}
		// No row updated: account missing or soft-deleted.
		return service.ErrAccountNotFound
	}
	if err := rows.Err(); err != nil {
		return err
	}
	// Fire the outbox event only on the increment that crosses the limit
	// (previous usage newUsed-amount was still below it).
	if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
		if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
			logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
			return err
		}
	}
	return nil
}
|
||||
@@ -0,0 +1,279 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-" + uuid.NewString(),
|
||||
Name: "billing",
|
||||
Quota: 1,
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
BalanceCost: 1.25,
|
||||
APIKeyQuotaCost: 1.25,
|
||||
APIKeyRateLimitCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result1)
|
||||
require.True(t, result1.Applied)
|
||||
require.True(t, result1.APIKeyQuotaExhausted)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result2)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed))
|
||||
require.InDelta(t, 1.25, quotaUsed, 0.000001)
|
||||
|
||||
var usage5h float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
|
||||
require.InDelta(t, 1.25, usage5h, 0.000001)
|
||||
|
||||
var status string
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
|
||||
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
|
||||
|
||||
var dedupCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
|
||||
require.Equal(t, 1, dedupCount)
|
||||
}
|
||||
|
||||
// TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling verifies
// that subscription usage counters are charged only once for a replayed
// request.
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)

	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
	})
	group := mustCreateGroup(t, client, &service.Group{
		Name:             "usage-billing-group-" + uuid.NewString(),
		Platform:         service.PlatformAnthropic,
		SubscriptionType: service.SubscriptionTypeSubscription,
	})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID:  user.ID,
		GroupID: &group.ID,
		Key:     "sk-usage-billing-sub-" + uuid.NewString(),
		Name:    "billing-sub",
	})
	subscription := mustCreateSubscription(t, client, &service.UserSubscription{
		UserID:  user.ID,
		GroupID: group.ID,
	})

	requestID := uuid.NewString()
	cmd := &service.UsageBillingCommand{
		RequestID:        requestID,
		APIKeyID:         apiKey.ID,
		UserID:           user.ID,
		AccountID:        0,
		SubscriptionID:   &subscription.ID,
		SubscriptionCost: 2.5,
	}

	result1, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.True(t, result1.Applied)

	// Replaying the identical command must be recognized as a duplicate.
	result2, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.False(t, result2.Applied)

	// The daily counter reflects exactly one application of the cost.
	var dailyUsage float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
	require.InDelta(t, 2.5, dailyUsage, 0.000001)
}
|
||||
|
||||
// TestUsageBillingRepositoryApply_RequestFingerprintConflict verifies that
// reusing a request id with a different command payload is rejected with
// ErrUsageBillingRequestConflict (the fingerprint is presumably derived from
// the command contents by Normalize — differing BalanceCost changes it).
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)

	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
		Balance:      100,
	})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID: user.ID,
		Key:    "sk-usage-billing-conflict-" + uuid.NewString(),
		Name:   "billing-conflict",
	})

	requestID := uuid.NewString()
	_, err := repo.Apply(ctx, &service.UsageBillingCommand{
		RequestID:   requestID,
		APIKeyID:    apiKey.ID,
		UserID:      user.ID,
		BalanceCost: 1.25,
	})
	require.NoError(t, err)

	// Same request id, different cost: must be flagged as a conflict rather
	// than silently deduplicated or double-billed.
	_, err = repo.Apply(ctx, &service.UsageBillingCommand{
		RequestID:   requestID,
		APIKeyID:    apiKey.ID,
		UserID:      user.ID,
		BalanceCost: 2.50,
	})
	require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-account-" + uuid.NewString(),
|
||||
Name: "billing-account",
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-quota-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 100.0,
|
||||
},
|
||||
})
|
||||
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 3.5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan("aUsed))
|
||||
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||
}
|
||||
|
||||
// TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows
// verifies that cleanup removes dedup rows older than the cutoff, keeps recent
// rows, and archives the removed keys.
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
	ctx := context.Background()
	repo := newDashboardAggregationRepositoryWithSQL(integrationDB)

	oldRequestID := "dedup-old-" + uuid.NewString()
	newRequestID := "dedup-new-" + uuid.NewString()
	oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
	newCreatedAt := time.Now().UTC().Add(-time.Hour)

	_, err := integrationDB.ExecContext(ctx, `
		INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
		VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
	`,
		oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
		newRequestID, strings.Repeat("b", 64), newCreatedAt,
	)
	require.NoError(t, err)

	// Cutoff of one year: the 400-day-old row is eligible, the 1-hour-old is not.
	require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))

	var oldCount int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
	require.Equal(t, 0, oldCount)

	var newCount int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
	require.Equal(t, 1, newCount)

	// The deleted key must have been copied into the archive table.
	var archivedCount int
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
	require.Equal(t, 1, archivedCount)
}
|
||||
|
||||
// TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey verifies that
// idempotency survives retention cleanup: once a dedup key has been moved to
// the archive table, replaying the original request is still a duplicate and
// does not charge the balance again.
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)
	repo := NewUsageBillingRepository(client, integrationDB)
	aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)

	user := mustCreateUser(t, client, &service.User{
		Email:        fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
		PasswordHash: "hash",
		Balance:      100,
	})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{
		UserID: user.ID,
		Key:    "sk-usage-billing-archive-" + uuid.NewString(),
		Name:   "billing-archive",
	})

	requestID := uuid.NewString()
	cmd := &service.UsageBillingCommand{
		RequestID:   requestID,
		APIKeyID:    apiKey.ID,
		UserID:      user.ID,
		BalanceCost: 1.25,
	}

	result1, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.True(t, result1.Applied)

	// Age the dedup row past the cleanup cutoff, then run cleanup so the key
	// is moved into usage_billing_dedup_archive.
	_, err = integrationDB.ExecContext(ctx, `
		UPDATE usage_billing_dedup
		SET created_at = $1
		WHERE request_id = $2 AND api_key_id = $3
	`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
	require.NoError(t, err)
	require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))

	result2, err := repo.Apply(ctx, cmd)
	require.NoError(t, err)
	require.False(t, result2.Applied)

	// The balance was only charged once.
	var balance float64
	require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
	require.InDelta(t, 98.75, balance, 0.000001)
}
|
||||
@@ -3,12 +3,14 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -17,11 +19,13 @@ import (
|
||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// usageLogSelectColumns is the canonical column list for usage_logs SELECTs;
// keep its order in sync with the corresponding Scan destinations.
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
@@ -47,18 +51,29 @@ type usageLogRepository struct {
|
||||
sql sqlExecutor
|
||||
db *sql.DB
|
||||
|
||||
createBatchOnce sync.Once
|
||||
createBatchCh chan usageLogCreateRequest
|
||||
createBatchOnce sync.Once
|
||||
createBatchCh chan usageLogCreateRequest
|
||||
bestEffortBatchOnce sync.Once
|
||||
bestEffortBatchCh chan usageLogBestEffortRequest
|
||||
bestEffortRecent *gocache.Cache
|
||||
}
|
||||
|
||||
const (
	// Synchronous create batching: small batches with a very short window,
	// bounded queue.
	usageLogCreateBatchMaxSize  = 64
	usageLogCreateBatchWindow   = 3 * time.Millisecond
	usageLogCreateBatchQueueCap = 4096
	// Grace period related to canceled create requests — presumably how long
	// a canceled enqueue still waits; confirm against the batch worker.
	usageLogCreateCancelWait = 2 * time.Second

	// Best-effort batching: larger batches, longer window, much deeper queue,
	// plus a short TTL for the recent-request cache (see bestEffortRecent).
	usageLogBestEffortBatchMaxSize  = 256
	usageLogBestEffortBatchWindow   = 20 * time.Millisecond
	usageLogBestEffortBatchQueueCap = 32768
	usageLogBestEffortRecentTTL     = 30 * time.Second
)
|
||||
|
||||
// usageLogCreateRequest is one queued synchronous usage-log insert; the
// outcome is delivered back to the enqueueing caller on resultCh.
type usageLogCreateRequest struct {
	log      *service.UsageLog
	prepared usageLogInsertPrepared
	shared   *usageLogCreateShared
	resultCh chan usageLogCreateResult
}
|
||||
|
||||
@@ -67,6 +82,12 @@ type usageLogCreateResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// usageLogBestEffortRequest is one queued best-effort usage-log insert. Only
// the pre-rendered row and the API key id travel with it; any error is
// reported on resultCh.
type usageLogBestEffortRequest struct {
	prepared usageLogInsertPrepared
	apiKeyID int64
	resultCh chan error
}
|
||||
|
||||
type usageLogInsertPrepared struct {
|
||||
createdAt time.Time
|
||||
requestID string
|
||||
@@ -80,6 +101,25 @@ type usageLogBatchState struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// usageLogBatchRow is the JSON-tagged per-row result exchanged by the batched
// insert path (Inserted distinguishes new rows from pre-existing ones —
// NOTE(review): confirm against the batch SQL, which is outside this view).
type usageLogBatchRow struct {
	RequestID string    `json:"request_id"`
	APIKeyID  int64     `json:"api_key_id"`
	ID        int64     `json:"id"`
	CreatedAt time.Time `json:"created_at"`
	Inserted  bool      `json:"inserted"`
}
|
||||
|
||||
// usageLogCreateShared carries the atomic lifecycle state of a queued create,
// shared between the enqueueing caller and the batch worker.
type usageLogCreateShared struct {
	state atomic.Int32
}

// Lifecycle states for a queued usage-log create.
const (
	usageLogCreateStateQueued int32 = iota
	usageLogCreateStateProcessing
	usageLogCreateStateCompleted
	usageLogCreateStateCanceled
)
|
||||
|
||||
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
|
||||
return newUsageLogRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
@@ -90,6 +130,7 @@ func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usage
|
||||
if db, ok := sqlq.(*sql.DB); ok {
|
||||
repo.db = db
|
||||
}
|
||||
repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute)
|
||||
return repo
|
||||
}
|
||||
|
||||
@@ -124,9 +165,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return r.createSingle(ctx, tx.Client(), log)
|
||||
}
|
||||
if r.db == nil {
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
requestID := strings.TrimSpace(log.RequestID)
|
||||
if requestID == "" {
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
@@ -135,11 +173,61 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
return r.createBatched(ctx, log)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error {
|
||||
if log == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
_, err := r.createSingle(ctx, tx.Client(), log)
|
||||
return err
|
||||
}
|
||||
if r.db == nil {
|
||||
_, err := r.createSingle(ctx, r.sql, log)
|
||||
return err
|
||||
}
|
||||
|
||||
r.ensureBestEffortBatcher()
|
||||
if r.bestEffortBatchCh == nil {
|
||||
_, err := r.createSingle(ctx, r.sql, log)
|
||||
return err
|
||||
}
|
||||
|
||||
req := usageLogBestEffortRequest{
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
apiKeyID: log.APIKeyID,
|
||||
resultCh: make(chan error, 1),
|
||||
}
|
||||
if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok {
|
||||
if _, exists := r.bestEffortRecent.Get(key); exists {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case r.bestEffortBatchCh <- req:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return errors.New("usage log best-effort queue full")
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-req.resultCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
|
||||
prepared := prepareUsageLogInsert(log)
|
||||
if sqlq == nil {
|
||||
sqlq = r.sql
|
||||
}
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO usage_logs (
|
||||
@@ -218,13 +306,15 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
|
||||
|
||||
req := usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
shared: &usageLogCreateShared{},
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case r.createBatchCh <- req:
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
||||
default:
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
@@ -233,7 +323,17 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
|
||||
case res := <-req.resultCh:
|
||||
return res.inserted, res.err
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) {
|
||||
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
||||
}
|
||||
timer := time.NewTimer(usageLogCreateCancelWait)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case res := <-req.resultCh:
|
||||
return res.inserted, res.err
|
||||
case <-timer.C:
|
||||
return false, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +347,16 @@ func (r *usageLogRepository) ensureCreateBatcher() {
|
||||
})
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ensureBestEffortBatcher() {
|
||||
if r == nil || r.db == nil {
|
||||
return
|
||||
}
|
||||
r.bestEffortBatchOnce.Do(func() {
|
||||
r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap)
|
||||
go r.runBestEffortBatcher(r.db)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
|
||||
for {
|
||||
first, ok := <-r.createBatchCh
|
||||
@@ -281,6 +391,40 @@ func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) {
|
||||
for {
|
||||
first, ok := <-r.bestEffortBatchCh
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize)
|
||||
batch = append(batch, first)
|
||||
|
||||
timer := time.NewTimer(usageLogBestEffortBatchWindow)
|
||||
bestEffortLoop:
|
||||
for len(batch) < usageLogBestEffortBatchMaxSize {
|
||||
select {
|
||||
case req, ok := <-r.bestEffortBatchCh:
|
||||
if !ok {
|
||||
break bestEffortLoop
|
||||
}
|
||||
batch = append(batch, req)
|
||||
case <-timer.C:
|
||||
break bestEffortLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
r.flushBestEffortBatch(db, batch)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
@@ -293,10 +437,19 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
|
||||
for _, req := range batch {
|
||||
if req.log == nil {
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil})
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
|
||||
continue
|
||||
}
|
||||
prepared := prepareUsageLogInsert(req.log)
|
||||
if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) {
|
||||
if req.shared.state.Load() == usageLogCreateStateCanceled {
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{
|
||||
inserted: false,
|
||||
err: service.MarkUsageLogCreateNotPersisted(context.Canceled),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
prepared := req.prepared
|
||||
if prepared.requestID == "" {
|
||||
fallback = append(fallback, req)
|
||||
continue
|
||||
@@ -310,10 +463,37 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
}
|
||||
|
||||
if len(uniqueOrder) > 0 {
|
||||
insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
|
||||
insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
|
||||
if err != nil {
|
||||
for _, key := range uniqueOrder {
|
||||
fallback = append(fallback, requestsByKey[key]...)
|
||||
if safeFallback {
|
||||
for _, key := range uniqueOrder {
|
||||
fallback = append(fallback, requestsByKey[key]...)
|
||||
}
|
||||
} else {
|
||||
for _, key := range uniqueOrder {
|
||||
reqs := requestsByKey[key]
|
||||
state, hasState := stateMap[key]
|
||||
inserted := insertedMap[key]
|
||||
for idx, req := range reqs {
|
||||
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
|
||||
if hasState {
|
||||
req.log.ID = state.ID
|
||||
req.log.CreatedAt = state.CreatedAt
|
||||
}
|
||||
switch {
|
||||
case inserted && idx == 0:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil})
|
||||
case inserted:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
|
||||
case hasState:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
|
||||
case idx == 0:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err})
|
||||
default:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, key := range uniqueOrder {
|
||||
@@ -321,7 +501,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
state, ok := stateMap[key]
|
||||
if !ok {
|
||||
for _, req := range reqs {
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{
|
||||
inserted: false,
|
||||
err: fmt.Errorf("usage log batch state missing for key=%s", key),
|
||||
})
|
||||
@@ -332,7 +512,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
req.log.ID = state.ID
|
||||
req.log.CreatedAt = state.CreatedAt
|
||||
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{
|
||||
inserted: idx == 0 && insertedMap[key],
|
||||
err: nil,
|
||||
})
|
||||
@@ -345,56 +525,366 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
return
|
||||
}
|
||||
|
||||
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
for _, req := range fallback {
|
||||
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
inserted, err := r.createSingle(fallbackCtx, db, req.log)
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err})
|
||||
cancel()
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) {
|
||||
func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
type bestEffortGroup struct {
|
||||
prepared usageLogInsertPrepared
|
||||
apiKeyID int64
|
||||
key string
|
||||
reqs []usageLogBestEffortRequest
|
||||
}
|
||||
|
||||
groupsByKey := make(map[string]*bestEffortGroup, len(batch))
|
||||
groupOrder := make([]*bestEffortGroup, 0, len(batch))
|
||||
preparedList := make([]usageLogInsertPrepared, 0, len(batch))
|
||||
|
||||
for idx, req := range batch {
|
||||
prepared := req.prepared
|
||||
key := fmt.Sprintf("__best_effort_%d", idx)
|
||||
if prepared.requestID != "" {
|
||||
key = usageLogBatchKey(prepared.requestID, req.apiKeyID)
|
||||
}
|
||||
group, exists := groupsByKey[key]
|
||||
if !exists {
|
||||
group = &bestEffortGroup{
|
||||
prepared: prepared,
|
||||
apiKeyID: req.apiKeyID,
|
||||
key: key,
|
||||
}
|
||||
groupsByKey[key] = group
|
||||
groupOrder = append(groupOrder, group)
|
||||
preparedList = append(preparedList, prepared)
|
||||
}
|
||||
group.reqs = append(group.reqs, req)
|
||||
}
|
||||
|
||||
if len(preparedList) == 0 {
|
||||
for _, req := range batch {
|
||||
sendUsageLogBestEffortResult(req.resultCh, nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
query, args := buildUsageLogBestEffortInsertQuery(preparedList)
|
||||
if _, err := db.ExecContext(ctx, query, args...); err != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err)
|
||||
for _, group := range groupOrder {
|
||||
singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared)
|
||||
if singleErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr)
|
||||
} else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
|
||||
r.bestEffortRecent.SetDefault(group.key, struct{}{})
|
||||
}
|
||||
for _, req := range group.reqs {
|
||||
sendUsageLogBestEffortResult(req.resultCh, singleErr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, group := range groupOrder {
|
||||
if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
|
||||
r.bestEffortRecent.SetDefault(group.key, struct{}{})
|
||||
}
|
||||
for _, req := range group.reqs {
|
||||
sendUsageLogBestEffortResult(req.resultCh, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sendUsageLogBestEffortResult(ch chan error, err error) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) {
|
||||
if req.shared != nil {
|
||||
req.shared.state.Store(usageLogCreateStateCompleted)
|
||||
}
|
||||
sendUsageLogCreateResult(req.resultCh, res)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) {
|
||||
if len(keys) == 0 {
|
||||
return map[string]bool{}, map[string]usageLogBatchState{}, nil
|
||||
return map[string]bool{}, map[string]usageLogBatchState{}, false, nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
var payload []byte
|
||||
if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil {
|
||||
return nil, nil, true, err
|
||||
}
|
||||
var rows []usageLogBatchRow
|
||||
if err := json.Unmarshal(payload, &rows); err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
insertedMap := make(map[string]bool, len(keys))
|
||||
for rows.Next() {
|
||||
var (
|
||||
requestID string
|
||||
apiKeyID int64
|
||||
id int64
|
||||
createdAt time.Time
|
||||
)
|
||||
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, nil, err
|
||||
stateMap := make(map[string]usageLogBatchState, len(keys))
|
||||
for _, row := range rows {
|
||||
key := usageLogBatchKey(row.RequestID, row.APIKeyID)
|
||||
insertedMap[key] = row.Inserted
|
||||
stateMap[key] = usageLogBatchState{
|
||||
ID: row.ID,
|
||||
CreatedAt: row.CreatedAt,
|
||||
}
|
||||
insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, nil, err
|
||||
if len(stateMap) != len(keys) {
|
||||
return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys))
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return insertedMap, stateMap, nil
|
||||
return insertedMap, stateMap, false, nil
|
||||
}
|
||||
|
||||
func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
|
||||
var query strings.Builder
|
||||
_, _ = query.WriteString(`
|
||||
WITH input (
|
||||
input_idx,
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*37)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("(")
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
args = append(args, idx)
|
||||
argPos++
|
||||
prepared := preparedByKey[key]
|
||||
for i := 0; i < len(prepared.args); i++ {
|
||||
_, _ = query.WriteString(",")
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
argPos++
|
||||
}
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.args...)
|
||||
}
|
||||
_, _ = query.WriteString(`
|
||||
),
|
||||
inserted AS (
|
||||
INSERT INTO usage_logs (
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
)
|
||||
SELECT
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO UPDATE
|
||||
SET request_id = usage_logs.request_id
|
||||
RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted
|
||||
)
|
||||
SELECT COALESCE(
|
||||
json_agg(
|
||||
json_build_object(
|
||||
'request_id', inserted.request_id,
|
||||
'api_key_id', inserted.api_key_id,
|
||||
'id', inserted.id,
|
||||
'created_at', inserted.created_at,
|
||||
'inserted', inserted.inserted
|
||||
)
|
||||
ORDER BY input.input_idx
|
||||
),
|
||||
'[]'::json
|
||||
)
|
||||
FROM input
|
||||
JOIN inserted
|
||||
ON inserted.request_id = input.request_id
|
||||
AND inserted.api_key_id = input.api_key_id
|
||||
`)
|
||||
return query.String(), args
|
||||
}
|
||||
|
||||
func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) {
|
||||
var query strings.Builder
|
||||
_, _ = query.WriteString(`
|
||||
WITH input (
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*36)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("(")
|
||||
for i := 0; i < len(prepared.args); i++ {
|
||||
if i > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
argPos++
|
||||
}
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.args...)
|
||||
}
|
||||
|
||||
_, _ = query.WriteString(`
|
||||
)
|
||||
INSERT INTO usage_logs (
|
||||
user_id,
|
||||
api_key_id,
|
||||
@@ -432,80 +922,101 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*36)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("(")
|
||||
prepared := preparedByKey[key]
|
||||
for i := 0; i < len(prepared.args); i++ {
|
||||
if i > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
argPos++
|
||||
}
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.args...)
|
||||
}
|
||||
_, _ = query.WriteString(`
|
||||
)
|
||||
SELECT
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING request_id, api_key_id, id, created_at
|
||||
`)
|
||||
|
||||
return query.String(), args
|
||||
}
|
||||
|
||||
func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) {
|
||||
var query strings.Builder
|
||||
_, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `)
|
||||
args := make([]any, 0, len(keys)*2)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(" OR ")
|
||||
}
|
||||
prepared := preparedByKey[key]
|
||||
apiKeyID := prepared.args[1]
|
||||
_, _ = query.WriteString("(request_id = $")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
_, _ = query.WriteString(" AND api_key_id = $")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos + 1))
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.requestID, apiKeyID)
|
||||
argPos += 2
|
||||
}
|
||||
|
||||
rows, err := db.QueryContext(ctx, query.String(), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
stateMap := make(map[string]usageLogBatchState, len(keys))
|
||||
for rows.Next() {
|
||||
var (
|
||||
requestID string
|
||||
apiKeyID int64
|
||||
id int64
|
||||
createdAt time.Time
|
||||
func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error {
|
||||
_, err := sqlq.ExecContext(ctx, `
|
||||
INSERT INTO usage_logs (
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
)
|
||||
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{
|
||||
ID: id,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stateMap, nil
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
@@ -597,6 +1108,14 @@ func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateRe
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) {
|
||||
requestID = strings.TrimSpace(requestID)
|
||||
if requestID == "" || r == nil || r.bestEffortRecent == nil {
|
||||
return "", false
|
||||
}
|
||||
return usageLogBatchKey(requestID, apiKeyID), true
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||||
rows, err := r.sql.QueryContext(ctx, query, id)
|
||||
|
||||
@@ -183,6 +183,214 @@ func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
const total = 8
|
||||
batch := make([]usageLogCreateRequest, 0, total)
|
||||
logs := make([]*service.UsageLog, 0, total)
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
logs = append(logs, log)
|
||||
batch = append(batch, usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
})
|
||||
}
|
||||
|
||||
repo.flushCreateBatch(integrationDB, batch)
|
||||
|
||||
insertedCount := 0
|
||||
var firstID int64
|
||||
for idx, req := range batch {
|
||||
res := <-req.resultCh
|
||||
require.NoError(t, res.err)
|
||||
if res.inserted {
|
||||
insertedCount++
|
||||
}
|
||||
require.NotZero(t, logs[idx].ID)
|
||||
if idx == 0 {
|
||||
firstID = logs[idx].ID
|
||||
} else {
|
||||
require.Equal(t, firstID, logs[idx].ID)
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, insertedCount)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
var count int
|
||||
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||
return err == nil && count == 1
|
||||
}, 3*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
// TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted checks that
// a log canceled AFTER it was queued onto the batch channel — but before a batcher
// flushed it — still surfaces a "not persisted" error to the caller.
// NOTE(review): the goroutine/channel ordering below is deliberate; do not reorder.
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
	client := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(client, integrationDB)
	// Replace the batch channel with a buffered one the test drains itself,
	// so the request is queued but never picked up by a real batch worker.
	repo.createBatchCh = make(chan usageLogCreateRequest, 1)

	// Fixture rows referenced by the usage log; unique suffixes isolate reruns.
	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})

	ctx, cancel := context.WithCancel(context.Background())
	errCh := make(chan error, 1)

	// createBatched blocks waiting for a batch result, so run it concurrently.
	go func() {
		_, err := repo.createBatched(ctx, &service.UsageLog{
			UserID:       user.ID,
			APIKeyID:     apiKey.ID,
			AccountID:    account.ID,
			RequestID:    uuid.NewString(),
			Model:        "claude-3",
			InputTokens:  10,
			OutputTokens: 20,
			TotalCost:    0.5,
			ActualCost:   0.5,
			CreatedAt:    time.Now().UTC(),
		})
		errCh <- err
	}()

	// Wait until the request is actually enqueued, then cancel the caller's context.
	req := <-repo.createBatchCh
	require.NotNil(t, req.shared)
	cancel()

	// The canceled caller must see a not-persisted error even though the
	// request is still sitting unflushed.
	err := <-errCh
	require.Error(t, err)
	require.True(t, service.IsUsageLogCreateNotPersisted(err))
	// Complete the dequeued request as a batcher would for a canceled entry,
	// so no goroutine is left waiting on its result channel.
	completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
req := usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
shared: &usageLogCreateShared{},
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
}
|
||||
req.shared.state.Store(usageLogCreateStateCanceled)
|
||||
|
||||
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
|
||||
|
||||
res := <-req.resultCh
|
||||
require.False(t, res.inserted)
|
||||
require.Error(t, res.err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
|
||||
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewUsageBillingRepository,
|
||||
NewIdempotencyRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
|
||||
Reference in New Issue
Block a user