Batch usage log writes in repository

This commit is contained in:
ius
2026-03-11 20:29:48 +08:00
parent 7455476c60
commit c9debc50b1
2 changed files with 538 additions and 63 deletions

View File

@@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -43,6 +45,39 @@ func safeDateFormat(granularity string) string {
type usageLogRepository struct {
client *dbent.Client
sql sqlExecutor
db *sql.DB
createBatchOnce sync.Once
createBatchCh chan usageLogCreateRequest
}
const (
usageLogCreateBatchMaxSize = 64
usageLogCreateBatchWindow = 3 * time.Millisecond
usageLogCreateBatchQueueCap = 4096
)
type usageLogCreateRequest struct {
log *service.UsageLog
resultCh chan usageLogCreateResult
}
type usageLogCreateResult struct {
inserted bool
err error
}
type usageLogInsertPrepared struct {
createdAt time.Time
requestID string
rateMultiplier float64
requestType int16
args []any
}
type usageLogBatchState struct {
ID int64
CreatedAt time.Time
}
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
@@ -51,7 +86,11 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
// 使用 scanSingleRow 替代 QueryRowContext保证 ent.Tx 作为 sqlExecutor 可用。
return &usageLogRepository{client: client, sql: sqlq}
repo := &usageLogRepository{client: client, sql: sqlq}
if db, ok := sqlq.(*sql.DB); ok {
repo.db = db
}
return repo
}
// getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤
@@ -82,24 +121,25 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
return false, nil
}
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL保证与其他更新同事务。
// 无事务时回退到默认的 *sql.DB 执行器。
sqlq := r.sql
if tx := dbent.TxFromContext(ctx); tx != nil {
sqlq = tx.Client()
return r.createSingle(ctx, tx.Client(), log)
}
createdAt := log.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now()
if r.db == nil {
return r.createSingle(ctx, r.sql, log)
}
requestID := strings.TrimSpace(log.RequestID)
if requestID == "" {
return r.createSingle(ctx, r.sql, log)
}
log.RequestID = requestID
return r.createBatched(ctx, log)
}
rateMultiplier := log.RateMultiplier
log.SyncRequestTypeAndLegacyFields()
requestType := int16(log.RequestType)
func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
prepared := prepareUsageLogInsert(log)
if sqlq == nil {
sqlq = r.sql
}
query := `
INSERT INTO usage_logs (
@@ -151,6 +191,336 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
RETURNING id, created_at
`
if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
return false, err
}
log.RateMultiplier = prepared.rateMultiplier
return false, nil
} else {
return false, err
}
}
log.RateMultiplier = prepared.rateMultiplier
return true, nil
}
func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) {
if r.db == nil {
return r.createSingle(ctx, r.sql, log)
}
r.ensureCreateBatcher()
if r.createBatchCh == nil {
return r.createSingle(ctx, r.sql, log)
}
req := usageLogCreateRequest{
log: log,
resultCh: make(chan usageLogCreateResult, 1),
}
select {
case r.createBatchCh <- req:
case <-ctx.Done():
return false, ctx.Err()
default:
return r.createSingle(ctx, r.sql, log)
}
select {
case res := <-req.resultCh:
return res.inserted, res.err
case <-ctx.Done():
return false, ctx.Err()
}
}
func (r *usageLogRepository) ensureCreateBatcher() {
if r == nil || r.db == nil {
return
}
r.createBatchOnce.Do(func() {
r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap)
go r.runCreateBatcher(r.db)
})
}
func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
for {
first, ok := <-r.createBatchCh
if !ok {
return
}
batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize)
batch = append(batch, first)
timer := time.NewTimer(usageLogCreateBatchWindow)
batchLoop:
for len(batch) < usageLogCreateBatchMaxSize {
select {
case req, ok := <-r.createBatchCh:
if !ok {
break batchLoop
}
batch = append(batch, req)
case <-timer.C:
break batchLoop
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
r.flushCreateBatch(db, batch)
}
}
func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
if len(batch) == 0 {
return
}
uniqueOrder := make([]string, 0, len(batch))
preparedByKey := make(map[string]usageLogInsertPrepared, len(batch))
requestsByKey := make(map[string][]usageLogCreateRequest, len(batch))
fallback := make([]usageLogCreateRequest, 0)
for _, req := range batch {
if req.log == nil {
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil})
continue
}
prepared := prepareUsageLogInsert(req.log)
if prepared.requestID == "" {
fallback = append(fallback, req)
continue
}
key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID)
if _, exists := requestsByKey[key]; !exists {
uniqueOrder = append(uniqueOrder, key)
preparedByKey[key] = prepared
}
requestsByKey[key] = append(requestsByKey[key], req)
}
if len(uniqueOrder) > 0 {
insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
if err != nil {
for _, key := range uniqueOrder {
fallback = append(fallback, requestsByKey[key]...)
}
} else {
for _, key := range uniqueOrder {
reqs := requestsByKey[key]
state, ok := stateMap[key]
if !ok {
for _, req := range reqs {
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
inserted: false,
err: fmt.Errorf("usage log batch state missing for key=%s", key),
})
}
continue
}
for idx, req := range reqs {
req.log.ID = state.ID
req.log.CreatedAt = state.CreatedAt
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
inserted: idx == 0 && insertedMap[key],
err: nil,
})
}
}
}
}
if len(fallback) == 0 {
return
}
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for _, req := range fallback {
inserted, err := r.createSingle(fallbackCtx, db, req.log)
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err})
}
}
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) {
if len(keys) == 0 {
return map[string]bool{}, map[string]usageLogBatchState{}, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, nil, err
}
insertedMap := make(map[string]bool, len(keys))
for rows.Next() {
var (
requestID string
apiKeyID int64
id int64
createdAt time.Time
)
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
_ = rows.Close()
return nil, nil, err
}
insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true
}
if err := rows.Err(); err != nil {
_ = rows.Close()
return nil, nil, err
}
_ = rows.Close()
stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey)
if err != nil {
return nil, nil, err
}
return insertedMap, stateMap, nil
}
func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
var query strings.Builder
_, _ = query.WriteString(`
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES `)
args := make([]any, 0, len(keys)*36)
argPos := 1
for idx, key := range keys {
if idx > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("(")
prepared := preparedByKey[key]
for i := 0; i < len(prepared.args); i++ {
if i > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
argPos++
}
_, _ = query.WriteString(")")
args = append(args, prepared.args...)
}
_, _ = query.WriteString(`
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING request_id, api_key_id, id, created_at
`)
return query.String(), args
}
func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) {
var query strings.Builder
_, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `)
args := make([]any, 0, len(keys)*2)
argPos := 1
for idx, key := range keys {
if idx > 0 {
_, _ = query.WriteString(" OR ")
}
prepared := preparedByKey[key]
apiKeyID := prepared.args[1]
_, _ = query.WriteString("(request_id = $")
_, _ = query.WriteString(strconv.Itoa(argPos))
_, _ = query.WriteString(" AND api_key_id = $")
_, _ = query.WriteString(strconv.Itoa(argPos + 1))
_, _ = query.WriteString(")")
args = append(args, prepared.requestID, apiKeyID)
argPos += 2
}
rows, err := db.QueryContext(ctx, query.String(), args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
stateMap := make(map[string]usageLogBatchState, len(keys))
for rows.Next() {
var (
requestID string
apiKeyID int64
id int64
createdAt time.Time
)
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
return nil, err
}
stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{
ID: id,
CreatedAt: createdAt,
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return stateMap, nil
}
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
createdAt := log.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now()
}
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
log.SyncRequestTypeAndLegacyFields()
requestType := int16(log.RequestType)
groupID := nullInt64(log.GroupID)
subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs)
@@ -167,58 +537,64 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
requestIDArg = requestID
}
args := []any{
log.UserID,
log.APIKeyID,
log.AccountID,
requestIDArg,
log.Model,
groupID,
subscriptionID,
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
log.CacheReadCost,
log.TotalCost,
log.ActualCost,
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
requestType,
log.Stream,
log.OpenAIWSMode,
duration,
firstToken,
userAgent,
ipAddress,
log.ImageCount,
imageSize,
mediaType,
serviceTier,
reasoningEffort,
log.CacheTTLOverridden,
createdAt,
return usageLogInsertPrepared{
createdAt: createdAt,
requestID: requestID,
rateMultiplier: rateMultiplier,
requestType: requestType,
args: []any{
log.UserID,
log.APIKeyID,
log.AccountID,
requestIDArg,
log.Model,
groupID,
subscriptionID,
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
log.CacheReadCost,
log.TotalCost,
log.ActualCost,
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
requestType,
log.Stream,
log.OpenAIWSMode,
duration,
firstToken,
userAgent,
ipAddress,
log.ImageCount,
imageSize,
mediaType,
serviceTier,
reasoningEffort,
log.CacheTTLOverridden,
createdAt,
},
}
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
return false, err
}
log.RateMultiplier = rateMultiplier
return false, nil
} else {
return false, err
}
}
func usageLogBatchKey(requestID string, apiKeyID int64) string {
return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10)
}
func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateResult) {
if ch == nil {
return
}
select {
case ch <- res:
default:
}
log.RateMultiplier = rateMultiplier
return true, nil
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {

View File

@@ -4,6 +4,8 @@ package repository
import (
"context"
"fmt"
"sync"
"testing"
"time"
@@ -14,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -84,6 +87,102 @@ func (s *UsageLogRepoSuite) TestCreate() {
s.Require().NotZero(log.ID)
}
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
const total = 16
results := make([]bool, total)
errs := make([]error, total)
logs := make([]*service.UsageLog, total)
var wg sync.WaitGroup
wg.Add(total)
for i := 0; i < total; i++ {
i := i
logs[i] = &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
go func() {
defer wg.Done()
results[i], errs[i] = repo.Create(ctx, logs[i])
}()
}
wg.Wait()
for i := 0; i < total; i++ {
require.NoError(t, errs[i])
require.True(t, results[i])
require.NotZero(t, logs[i].ID)
}
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
require.Equal(t, total, count)
}
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
inserted1, err1 := repo.Create(ctx, log1)
inserted2, err2 := repo.Create(ctx, log2)
require.NoError(t, err1)
require.NoError(t, err2)
require.True(t, inserted1)
require.False(t, inserted2)
require.Equal(t, log1.ID, log2.ID)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})