feat: decouple billing correctness from usage log batching

This commit is contained in:
ius
2026-03-12 16:53:18 +08:00
parent c9debc50b1
commit 611fd884bd
37 changed files with 3379 additions and 330 deletions

View File

@@ -3,12 +3,14 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -17,11 +19,13 @@ import (
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
@@ -47,18 +51,29 @@ type usageLogRepository struct {
sql sqlExecutor
db *sql.DB
createBatchOnce sync.Once
createBatchCh chan usageLogCreateRequest
createBatchOnce sync.Once
createBatchCh chan usageLogCreateRequest
bestEffortBatchOnce sync.Once
bestEffortBatchCh chan usageLogBestEffortRequest
bestEffortRecent *gocache.Cache
}
const (
usageLogCreateBatchMaxSize = 64
usageLogCreateBatchWindow = 3 * time.Millisecond
usageLogCreateBatchQueueCap = 4096
usageLogCreateCancelWait = 2 * time.Second
usageLogBestEffortBatchMaxSize = 256
usageLogBestEffortBatchWindow = 20 * time.Millisecond
usageLogBestEffortBatchQueueCap = 32768
usageLogBestEffortRecentTTL = 30 * time.Second
)
type usageLogCreateRequest struct {
log *service.UsageLog
prepared usageLogInsertPrepared
shared *usageLogCreateShared
resultCh chan usageLogCreateResult
}
@@ -67,6 +82,12 @@ type usageLogCreateResult struct {
err error
}
type usageLogBestEffortRequest struct {
prepared usageLogInsertPrepared
apiKeyID int64
resultCh chan error
}
type usageLogInsertPrepared struct {
createdAt time.Time
requestID string
@@ -80,6 +101,25 @@ type usageLogBatchState struct {
CreatedAt time.Time
}
type usageLogBatchRow struct {
RequestID string `json:"request_id"`
APIKeyID int64 `json:"api_key_id"`
ID int64 `json:"id"`
CreatedAt time.Time `json:"created_at"`
Inserted bool `json:"inserted"`
}
type usageLogCreateShared struct {
state atomic.Int32
}
const (
usageLogCreateStateQueued int32 = iota
usageLogCreateStateProcessing
usageLogCreateStateCompleted
usageLogCreateStateCanceled
)
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
return newUsageLogRepositoryWithSQL(client, sqlDB)
}
@@ -90,6 +130,7 @@ func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usage
if db, ok := sqlq.(*sql.DB); ok {
repo.db = db
}
repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute)
return repo
}
@@ -124,9 +165,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
if tx := dbent.TxFromContext(ctx); tx != nil {
return r.createSingle(ctx, tx.Client(), log)
}
if r.db == nil {
return r.createSingle(ctx, r.sql, log)
}
requestID := strings.TrimSpace(log.RequestID)
if requestID == "" {
return r.createSingle(ctx, r.sql, log)
@@ -135,11 +173,61 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
return r.createBatched(ctx, log)
}
func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error {
if log == nil {
return nil
}
if tx := dbent.TxFromContext(ctx); tx != nil {
_, err := r.createSingle(ctx, tx.Client(), log)
return err
}
if r.db == nil {
_, err := r.createSingle(ctx, r.sql, log)
return err
}
r.ensureBestEffortBatcher()
if r.bestEffortBatchCh == nil {
_, err := r.createSingle(ctx, r.sql, log)
return err
}
req := usageLogBestEffortRequest{
prepared: prepareUsageLogInsert(log),
apiKeyID: log.APIKeyID,
resultCh: make(chan error, 1),
}
if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok {
if _, exists := r.bestEffortRecent.Get(key); exists {
return nil
}
}
select {
case r.bestEffortBatchCh <- req:
case <-ctx.Done():
return ctx.Err()
default:
return errors.New("usage log best-effort queue full")
}
select {
case err := <-req.resultCh:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
prepared := prepareUsageLogInsert(log)
if sqlq == nil {
sqlq = r.sql
}
if ctx != nil && ctx.Err() != nil {
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
}
query := `
INSERT INTO usage_logs (
@@ -218,13 +306,15 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
req := usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
shared: &usageLogCreateShared{},
resultCh: make(chan usageLogCreateResult, 1),
}
select {
case r.createBatchCh <- req:
case <-ctx.Done():
return false, ctx.Err()
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
default:
return r.createSingle(ctx, r.sql, log)
}
@@ -233,7 +323,17 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
case res := <-req.resultCh:
return res.inserted, res.err
case <-ctx.Done():
return false, ctx.Err()
if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) {
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
}
timer := time.NewTimer(usageLogCreateCancelWait)
defer timer.Stop()
select {
case res := <-req.resultCh:
return res.inserted, res.err
case <-timer.C:
return false, ctx.Err()
}
}
}
@@ -247,6 +347,16 @@ func (r *usageLogRepository) ensureCreateBatcher() {
})
}
func (r *usageLogRepository) ensureBestEffortBatcher() {
if r == nil || r.db == nil {
return
}
r.bestEffortBatchOnce.Do(func() {
r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap)
go r.runBestEffortBatcher(r.db)
})
}
func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
for {
first, ok := <-r.createBatchCh
@@ -281,6 +391,40 @@ func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
}
}
func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) {
for {
first, ok := <-r.bestEffortBatchCh
if !ok {
return
}
batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize)
batch = append(batch, first)
timer := time.NewTimer(usageLogBestEffortBatchWindow)
bestEffortLoop:
for len(batch) < usageLogBestEffortBatchMaxSize {
select {
case req, ok := <-r.bestEffortBatchCh:
if !ok {
break bestEffortLoop
}
batch = append(batch, req)
case <-timer.C:
break bestEffortLoop
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
r.flushBestEffortBatch(db, batch)
}
}
func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
if len(batch) == 0 {
return
@@ -293,10 +437,19 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
for _, req := range batch {
if req.log == nil {
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil})
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
continue
}
prepared := prepareUsageLogInsert(req.log)
if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) {
if req.shared.state.Load() == usageLogCreateStateCanceled {
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: false,
err: service.MarkUsageLogCreateNotPersisted(context.Canceled),
})
continue
}
}
prepared := req.prepared
if prepared.requestID == "" {
fallback = append(fallback, req)
continue
@@ -310,10 +463,37 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
}
if len(uniqueOrder) > 0 {
insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
if err != nil {
for _, key := range uniqueOrder {
fallback = append(fallback, requestsByKey[key]...)
if safeFallback {
for _, key := range uniqueOrder {
fallback = append(fallback, requestsByKey[key]...)
}
} else {
for _, key := range uniqueOrder {
reqs := requestsByKey[key]
state, hasState := stateMap[key]
inserted := insertedMap[key]
for idx, req := range reqs {
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
if hasState {
req.log.ID = state.ID
req.log.CreatedAt = state.CreatedAt
}
switch {
case inserted && idx == 0:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil})
case inserted:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
case hasState:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
case idx == 0:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err})
default:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
}
}
}
}
} else {
for _, key := range uniqueOrder {
@@ -321,7 +501,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
state, ok := stateMap[key]
if !ok {
for _, req := range reqs {
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: false,
err: fmt.Errorf("usage log batch state missing for key=%s", key),
})
@@ -332,7 +512,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
req.log.ID = state.ID
req.log.CreatedAt = state.CreatedAt
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: idx == 0 && insertedMap[key],
err: nil,
})
@@ -345,56 +525,366 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
return
}
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for _, req := range fallback {
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
inserted, err := r.createSingle(fallbackCtx, db, req.log)
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err})
cancel()
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err})
}
}
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) {
func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) {
if len(batch) == 0 {
return
}
type bestEffortGroup struct {
prepared usageLogInsertPrepared
apiKeyID int64
key string
reqs []usageLogBestEffortRequest
}
groupsByKey := make(map[string]*bestEffortGroup, len(batch))
groupOrder := make([]*bestEffortGroup, 0, len(batch))
preparedList := make([]usageLogInsertPrepared, 0, len(batch))
for idx, req := range batch {
prepared := req.prepared
key := fmt.Sprintf("__best_effort_%d", idx)
if prepared.requestID != "" {
key = usageLogBatchKey(prepared.requestID, req.apiKeyID)
}
group, exists := groupsByKey[key]
if !exists {
group = &bestEffortGroup{
prepared: prepared,
apiKeyID: req.apiKeyID,
key: key,
}
groupsByKey[key] = group
groupOrder = append(groupOrder, group)
preparedList = append(preparedList, prepared)
}
group.reqs = append(group.reqs, req)
}
if len(preparedList) == 0 {
for _, req := range batch {
sendUsageLogBestEffortResult(req.resultCh, nil)
}
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
query, args := buildUsageLogBestEffortInsertQuery(preparedList)
if _, err := db.ExecContext(ctx, query, args...); err != nil {
logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err)
for _, group := range groupOrder {
singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared)
if singleErr != nil {
logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr)
} else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
r.bestEffortRecent.SetDefault(group.key, struct{}{})
}
for _, req := range group.reqs {
sendUsageLogBestEffortResult(req.resultCh, singleErr)
}
}
return
}
for _, group := range groupOrder {
if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
r.bestEffortRecent.SetDefault(group.key, struct{}{})
}
for _, req := range group.reqs {
sendUsageLogBestEffortResult(req.resultCh, nil)
}
}
}
func sendUsageLogBestEffortResult(ch chan error, err error) {
if ch == nil {
return
}
select {
case ch <- err:
default:
}
}
func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) {
if req.shared != nil {
req.shared.state.Store(usageLogCreateStateCompleted)
}
sendUsageLogCreateResult(req.resultCh, res)
}
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) {
if len(keys) == 0 {
return map[string]bool{}, map[string]usageLogBatchState{}, nil
return map[string]bool{}, map[string]usageLogBatchState{}, false, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, nil, err
var payload []byte
if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil {
return nil, nil, true, err
}
var rows []usageLogBatchRow
if err := json.Unmarshal(payload, &rows); err != nil {
return nil, nil, false, err
}
insertedMap := make(map[string]bool, len(keys))
for rows.Next() {
var (
requestID string
apiKeyID int64
id int64
createdAt time.Time
)
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
_ = rows.Close()
return nil, nil, err
stateMap := make(map[string]usageLogBatchState, len(keys))
for _, row := range rows {
key := usageLogBatchKey(row.RequestID, row.APIKeyID)
insertedMap[key] = row.Inserted
stateMap[key] = usageLogBatchState{
ID: row.ID,
CreatedAt: row.CreatedAt,
}
insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true
}
if err := rows.Err(); err != nil {
_ = rows.Close()
return nil, nil, err
if len(stateMap) != len(keys) {
return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys))
}
_ = rows.Close()
stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey)
if err != nil {
return nil, nil, err
}
return insertedMap, stateMap, nil
return insertedMap, stateMap, false, nil
}
func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
var query strings.Builder
_, _ = query.WriteString(`
WITH input (
input_idx,
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*37)
argPos := 1
for idx, key := range keys {
if idx > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("(")
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
args = append(args, idx)
argPos++
prepared := preparedByKey[key]
for i := 0; i < len(prepared.args); i++ {
_, _ = query.WriteString(",")
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
argPos++
}
_, _ = query.WriteString(")")
args = append(args, prepared.args...)
}
_, _ = query.WriteString(`
),
inserted AS (
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
)
SELECT
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO UPDATE
SET request_id = usage_logs.request_id
RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted
)
SELECT COALESCE(
json_agg(
json_build_object(
'request_id', inserted.request_id,
'api_key_id', inserted.api_key_id,
'id', inserted.id,
'created_at', inserted.created_at,
'inserted', inserted.inserted
)
ORDER BY input.input_idx
),
'[]'::json
)
FROM input
JOIN inserted
ON inserted.request_id = input.request_id
AND inserted.api_key_id = input.api_key_id
`)
return query.String(), args
}
func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) {
var query strings.Builder
_, _ = query.WriteString(`
WITH input (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*36)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("(")
for i := 0; i < len(prepared.args); i++ {
if i > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
argPos++
}
_, _ = query.WriteString(")")
args = append(args, prepared.args...)
}
_, _ = query.WriteString(`
)
INSERT INTO usage_logs (
user_id,
api_key_id,
@@ -432,80 +922,101 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES `)
args := make([]any, 0, len(keys)*36)
argPos := 1
for idx, key := range keys {
if idx > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("(")
prepared := preparedByKey[key]
for i := 0; i < len(prepared.args); i++ {
if i > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
argPos++
}
_, _ = query.WriteString(")")
args = append(args, prepared.args...)
}
_, _ = query.WriteString(`
)
SELECT
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING request_id, api_key_id, id, created_at
`)
return query.String(), args
}
func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) {
var query strings.Builder
_, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `)
args := make([]any, 0, len(keys)*2)
argPos := 1
for idx, key := range keys {
if idx > 0 {
_, _ = query.WriteString(" OR ")
}
prepared := preparedByKey[key]
apiKeyID := prepared.args[1]
_, _ = query.WriteString("(request_id = $")
_, _ = query.WriteString(strconv.Itoa(argPos))
_, _ = query.WriteString(" AND api_key_id = $")
_, _ = query.WriteString(strconv.Itoa(argPos + 1))
_, _ = query.WriteString(")")
args = append(args, prepared.requestID, apiKeyID)
argPos += 2
}
rows, err := db.QueryContext(ctx, query.String(), args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
stateMap := make(map[string]usageLogBatchState, len(keys))
for rows.Next() {
var (
requestID string
apiKeyID int64
id int64
createdAt time.Time
func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error {
_, err := sqlq.ExecContext(ctx, `
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
)
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
return nil, err
}
stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{
ID: id,
CreatedAt: createdAt,
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return stateMap, nil
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
return err
}
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
@@ -597,6 +1108,14 @@ func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateRe
}
}
func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) {
requestID = strings.TrimSpace(requestID)
if requestID == "" || r == nil || r.bestEffortRecent == nil {
return "", false
}
return usageLogBatchKey(requestID, apiKeyID), true
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
rows, err := r.sql.QueryContext(ctx, query, id)