mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-25 17:14:45 +08:00
Merge pull request #952 from xvhuan/feat/billing-ledger-decouple-usage-log-20260312
feat: 解耦计费正确性与 usage_logs 批量写压
This commit is contained in:
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
const usageLogsCleanupBatchSize = 10000
|
||||
const usageBillingDedupCleanupBatchSize = 10000
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
if sqlDB == nil {
|
||||
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid
|
||||
FROM usage_logs
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM usage_logs
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageLogsCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM usage_billing_dedup
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
), archived AS (
|
||||
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM victims
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
)
|
||||
DELETE FROM usage_billing_dedup
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageBillingDedupCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
|
||||
@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
SetStatus(k.Status)
|
||||
if k.Quota != 0 {
|
||||
create.SetQuota(k.Quota)
|
||||
}
|
||||
if k.QuotaUsed != 0 {
|
||||
create.SetQuotaUsed(k.QuotaUsed)
|
||||
}
|
||||
if k.RateLimit5h != 0 {
|
||||
create.SetRateLimit5h(k.RateLimit5h)
|
||||
}
|
||||
if k.RateLimit1d != 0 {
|
||||
create.SetRateLimit1d(k.RateLimit1d)
|
||||
}
|
||||
if k.RateLimit7d != 0 {
|
||||
create.SetRateLimit7d(k.RateLimit7d)
|
||||
}
|
||||
if k.Usage5h != 0 {
|
||||
create.SetUsage5h(k.Usage5h)
|
||||
}
|
||||
if k.Usage1d != 0 {
|
||||
create.SetUsage1d(k.Usage1d)
|
||||
}
|
||||
if k.Usage7d != 0 {
|
||||
create.SetUsage7d(k.Usage7d)
|
||||
}
|
||||
if k.Window5hStart != nil {
|
||||
create.SetWindow5hStart(*k.Window5hStart)
|
||||
}
|
||||
if k.Window1dStart != nil {
|
||||
create.SetWindow1dStart(*k.Window1dStart)
|
||||
}
|
||||
if k.Window7dStart != nil {
|
||||
create.SetWindow7dStart(*k.Window7dStart)
|
||||
}
|
||||
if k.ExpiresAt != nil {
|
||||
create.SetExpiresAt(*k.ExpiresAt)
|
||||
}
|
||||
if k.GroupID != nil {
|
||||
create.SetGroupID(*k.GroupID)
|
||||
}
|
||||
|
||||
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
|
||||
// usage_billing_dedup: billing idempotency narrow table
|
||||
var usageBillingDedupRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||
|
||||
var usageBillingDedupArchiveRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||
}
|
||||
|
||||
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||
t.Helper()
|
||||
|
||||
var exists bool
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = $1
|
||||
AND indexname = $2
|
||||
)
|
||||
`, table, index).Scan(&exists)
|
||||
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||
require.True(t, exists, "expected index %s on %s", index, table)
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
308
backend/internal/repository/usage_billing_repo.go
Normal file
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type usageBillingRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||||
return &usageBillingRepository{db: sqlDB}
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||||
if cmd == nil {
|
||||
return &service.UsageBillingApplyResult{}, nil
|
||||
}
|
||||
if r == nil || r.db == nil {
|
||||
return nil, errors.New("usage billing repository db is nil")
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
if cmd.RequestID == "" {
|
||||
return nil, service.ErrUsageBillingRequestIDRequired
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !applied {
|
||||
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||||
}
|
||||
|
||||
result := &service.UsageBillingApplyResult{Applied: true}
|
||||
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx = nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
||||
var id int64
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id
|
||||
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
var existingFingerprint string
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var archivedFingerprint string
|
||||
err = tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup_archive
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
||||
if err == nil {
|
||||
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
||||
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
||||
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.BalanceCost > 0 {
|
||||
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.APIKeyQuotaCost > 0 {
|
||||
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.APIKeyQuotaExhausted = exhausted
|
||||
}
|
||||
|
||||
if cmd.APIKeyRateLimitCost > 0 {
|
||||
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||||
const updateSQL = `
|
||||
UPDATE user_subscriptions us
|
||||
SET
|
||||
daily_usage_usd = us.daily_usage_usd + $1,
|
||||
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||||
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||||
updated_at = NOW()
|
||||
FROM groups g
|
||||
WHERE us.id = $2
|
||||
AND us.deleted_at IS NULL
|
||||
AND us.group_id = g.id
|
||||
AND g.deleted_at IS NULL
|
||||
`
|
||||
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE users
|
||||
SET balance = balance - $1,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, amount, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||||
var exhausted bool
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
UPDATE api_keys
|
||||
SET quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0
|
||||
AND status = $3
|
||||
AND quota_used < quota
|
||||
AND quota_used + $1 >= quota
|
||||
THEN $4
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
||||
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, service.ErrAPIKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exhausted, nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, cost, apiKeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
||||
rows, err := tx.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-" + uuid.NewString(),
|
||||
Name: "billing",
|
||||
Quota: 1,
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
BalanceCost: 1.25,
|
||||
APIKeyQuotaCost: 1.25,
|
||||
APIKeyRateLimitCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result1)
|
||||
require.True(t, result1.Applied)
|
||||
require.True(t, result1.APIKeyQuotaExhausted)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result2)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed))
|
||||
require.InDelta(t, 1.25, quotaUsed, 0.000001)
|
||||
|
||||
var usage5h float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
|
||||
require.InDelta(t, 1.25, usage5h, 0.000001)
|
||||
|
||||
var status string
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
|
||||
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
|
||||
|
||||
var dedupCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
|
||||
require.Equal(t, 1, dedupCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
group := mustCreateGroup(t, client, &service.Group{
|
||||
Name: "usage-billing-group-" + uuid.NewString(),
|
||||
Platform: service.PlatformAnthropic,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
GroupID: &group.ID,
|
||||
Key: "sk-usage-billing-sub-" + uuid.NewString(),
|
||||
Name: "billing-sub",
|
||||
})
|
||||
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: 0,
|
||||
SubscriptionID: &subscription.ID,
|
||||
SubscriptionCost: 2.5,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var dailyUsage float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
|
||||
require.InDelta(t, 2.5, dailyUsage, 0.000001)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
|
||||
Name: "billing-conflict",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 2.50,
|
||||
})
|
||||
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-account-" + uuid.NewString(),
|
||||
Name: "billing-account",
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-quota-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 100.0,
|
||||
},
|
||||
})
|
||||
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 3.5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan("aUsed))
|
||||
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
oldRequestID := "dedup-old-" + uuid.NewString()
|
||||
newRequestID := "dedup-new-" + uuid.NewString()
|
||||
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
|
||||
newCreatedAt := time.Now().UTC().Add(-time.Hour)
|
||||
|
||||
_, err := integrationDB.ExecContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
|
||||
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
|
||||
`,
|
||||
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
|
||||
newRequestID, strings.Repeat("b", 64), newCreatedAt,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
var oldCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
|
||||
require.Equal(t, 0, oldCount)
|
||||
|
||||
var newCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
|
||||
require.Equal(t, 1, newCount)
|
||||
|
||||
var archivedCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
|
||||
require.Equal(t, 1, archivedCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-archive-" + uuid.NewString(),
|
||||
Name: "billing-archive",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
_, err = integrationDB.ExecContext(ctx, `
|
||||
UPDATE usage_billing_dedup
|
||||
SET created_at = $1
|
||||
WHERE request_id = $2 AND api_key_id = $3
|
||||
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
s.Require().NotZero(log.ID)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
|
||||
|
||||
const total = 16
|
||||
results := make([]bool, total)
|
||||
errs := make([]error, total)
|
||||
logs := make([]*service.UsageLog, total)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(total)
|
||||
for i := 0; i < total; i++ {
|
||||
i := i
|
||||
logs[i] = &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i], errs[i] = repo.Create(ctx, logs[i])
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
require.NoError(t, errs[i])
|
||||
require.True(t, results[i])
|
||||
require.NotZero(t, logs[i].ID)
|
||||
}
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
|
||||
require.Equal(t, total, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
inserted1, err1 := repo.Create(ctx, log1)
|
||||
inserted2, err2 := repo.Create(ctx, log2)
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
require.True(t, inserted1)
|
||||
require.False(t, inserted2)
|
||||
require.Equal(t, log1.ID, log2.ID)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
const total = 8
|
||||
batch := make([]usageLogCreateRequest, 0, total)
|
||||
logs := make([]*service.UsageLog, 0, total)
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
logs = append(logs, log)
|
||||
batch = append(batch, usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
})
|
||||
}
|
||||
|
||||
repo.flushCreateBatch(integrationDB, batch)
|
||||
|
||||
insertedCount := 0
|
||||
var firstID int64
|
||||
for idx, req := range batch {
|
||||
res := <-req.resultCh
|
||||
require.NoError(t, res.err)
|
||||
if res.inserted {
|
||||
insertedCount++
|
||||
}
|
||||
require.NotZero(t, logs[idx].ID)
|
||||
if idx == 0 {
|
||||
firstID = logs[idx].ID
|
||||
} else {
|
||||
require.Equal(t, firstID, logs[idx].ID)
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, insertedCount)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
var count int
|
||||
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||
return err == nil && count == 1
|
||||
}, 3*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
|
||||
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
|
||||
|
||||
err := repo.CreateBestEffort(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateDropped(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
repo.createBatchCh <- usageLogCreateRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
_, err := repo.createBatched(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
req := <-repo.createBatchCh
|
||||
require.NotNil(t, req.shared)
|
||||
cancel()
|
||||
|
||||
err := <-errCh
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
req := usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
shared: &usageLogCreateShared{},
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
}
|
||||
req.shared.state.Store(usageLogCreateStateCanceled)
|
||||
|
||||
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
|
||||
|
||||
res := <-req.resultCh
|
||||
require.False(t, res.inserted)
|
||||
require.Error(t, res.err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
|
||||
@@ -3,8 +3,11 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-batch-no-update",
|
||||
Model: "gpt-5",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalCost: 1.2,
|
||||
ActualCost: 1.2,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
prepared := prepareUsageLogInsert(log)
|
||||
|
||||
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
|
||||
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
|
||||
})
|
||||
|
||||
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
|
||||
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewUsageBillingRepository,
|
||||
NewIdempotencyRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
|
||||
Reference in New Issue
Block a user