mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
Batch usage log writes in repository
This commit is contained in:
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -43,6 +45,39 @@ func safeDateFormat(granularity string) string {
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
db *sql.DB
|
||||
|
||||
createBatchOnce sync.Once
|
||||
createBatchCh chan usageLogCreateRequest
|
||||
}
|
||||
|
||||
const (
|
||||
usageLogCreateBatchMaxSize = 64
|
||||
usageLogCreateBatchWindow = 3 * time.Millisecond
|
||||
usageLogCreateBatchQueueCap = 4096
|
||||
)
|
||||
|
||||
type usageLogCreateRequest struct {
|
||||
log *service.UsageLog
|
||||
resultCh chan usageLogCreateResult
|
||||
}
|
||||
|
||||
type usageLogCreateResult struct {
|
||||
inserted bool
|
||||
err error
|
||||
}
|
||||
|
||||
type usageLogInsertPrepared struct {
|
||||
createdAt time.Time
|
||||
requestID string
|
||||
rateMultiplier float64
|
||||
requestType int16
|
||||
args []any
|
||||
}
|
||||
|
||||
type usageLogBatchState struct {
|
||||
ID int64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
|
||||
@@ -51,7 +86,11 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
|
||||
|
||||
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
||||
// 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。
|
||||
return &usageLogRepository{client: client, sql: sqlq}
|
||||
repo := &usageLogRepository{client: client, sql: sqlq}
|
||||
if db, ok := sqlq.(*sql.DB); ok {
|
||||
repo.db = db
|
||||
}
|
||||
return repo
|
||||
}
|
||||
|
||||
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
|
||||
@@ -82,24 +121,25 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
|
||||
// 无事务时回退到默认的 *sql.DB 执行器。
|
||||
sqlq := r.sql
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
sqlq = tx.Client()
|
||||
return r.createSingle(ctx, tx.Client(), log)
|
||||
}
|
||||
|
||||
createdAt := log.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now()
|
||||
if r.db == nil {
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(log.RequestID)
|
||||
if requestID == "" {
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
log.RequestID = requestID
|
||||
return r.createBatched(ctx, log)
|
||||
}
|
||||
|
||||
rateMultiplier := log.RateMultiplier
|
||||
log.SyncRequestTypeAndLegacyFields()
|
||||
requestType := int16(log.RequestType)
|
||||
func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
|
||||
prepared := prepareUsageLogInsert(log)
|
||||
if sqlq == nil {
|
||||
sqlq = r.sql
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO usage_logs (
|
||||
@@ -151,6 +191,336 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" {
|
||||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||||
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||||
return false, err
|
||||
}
|
||||
log.RateMultiplier = prepared.rateMultiplier
|
||||
return false, nil
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
log.RateMultiplier = prepared.rateMultiplier
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if r.db == nil {
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
r.ensureCreateBatcher()
|
||||
if r.createBatchCh == nil {
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
|
||||
req := usageLogCreateRequest{
|
||||
log: log,
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case r.createBatchCh <- req:
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
default:
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
|
||||
select {
|
||||
case res := <-req.resultCh:
|
||||
return res.inserted, res.err
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ensureCreateBatcher() {
|
||||
if r == nil || r.db == nil {
|
||||
return
|
||||
}
|
||||
r.createBatchOnce.Do(func() {
|
||||
r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap)
|
||||
go r.runCreateBatcher(r.db)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
|
||||
for {
|
||||
first, ok := <-r.createBatchCh
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize)
|
||||
batch = append(batch, first)
|
||||
|
||||
timer := time.NewTimer(usageLogCreateBatchWindow)
|
||||
batchLoop:
|
||||
for len(batch) < usageLogCreateBatchMaxSize {
|
||||
select {
|
||||
case req, ok := <-r.createBatchCh:
|
||||
if !ok {
|
||||
break batchLoop
|
||||
}
|
||||
batch = append(batch, req)
|
||||
case <-timer.C:
|
||||
break batchLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
r.flushCreateBatch(db, batch)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
uniqueOrder := make([]string, 0, len(batch))
|
||||
preparedByKey := make(map[string]usageLogInsertPrepared, len(batch))
|
||||
requestsByKey := make(map[string][]usageLogCreateRequest, len(batch))
|
||||
fallback := make([]usageLogCreateRequest, 0)
|
||||
|
||||
for _, req := range batch {
|
||||
if req.log == nil {
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil})
|
||||
continue
|
||||
}
|
||||
prepared := prepareUsageLogInsert(req.log)
|
||||
if prepared.requestID == "" {
|
||||
fallback = append(fallback, req)
|
||||
continue
|
||||
}
|
||||
key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID)
|
||||
if _, exists := requestsByKey[key]; !exists {
|
||||
uniqueOrder = append(uniqueOrder, key)
|
||||
preparedByKey[key] = prepared
|
||||
}
|
||||
requestsByKey[key] = append(requestsByKey[key], req)
|
||||
}
|
||||
|
||||
if len(uniqueOrder) > 0 {
|
||||
insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
|
||||
if err != nil {
|
||||
for _, key := range uniqueOrder {
|
||||
fallback = append(fallback, requestsByKey[key]...)
|
||||
}
|
||||
} else {
|
||||
for _, key := range uniqueOrder {
|
||||
reqs := requestsByKey[key]
|
||||
state, ok := stateMap[key]
|
||||
if !ok {
|
||||
for _, req := range reqs {
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
|
||||
inserted: false,
|
||||
err: fmt.Errorf("usage log batch state missing for key=%s", key),
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
for idx, req := range reqs {
|
||||
req.log.ID = state.ID
|
||||
req.log.CreatedAt = state.CreatedAt
|
||||
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
|
||||
inserted: idx == 0 && insertedMap[key],
|
||||
err: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(fallback) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
for _, req := range fallback {
|
||||
inserted, err := r.createSingle(fallbackCtx, db, req.log)
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) {
|
||||
if len(keys) == 0 {
|
||||
return map[string]bool{}, map[string]usageLogBatchState{}, nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
insertedMap := make(map[string]bool, len(keys))
|
||||
for rows.Next() {
|
||||
var (
|
||||
requestID string
|
||||
apiKeyID int64
|
||||
id int64
|
||||
createdAt time.Time
|
||||
)
|
||||
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return insertedMap, stateMap, nil
|
||||
}
|
||||
|
||||
func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
|
||||
var query strings.Builder
|
||||
_, _ = query.WriteString(`
|
||||
INSERT INTO usage_logs (
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*36)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("(")
|
||||
prepared := preparedByKey[key]
|
||||
for i := 0; i < len(prepared.args); i++ {
|
||||
if i > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
argPos++
|
||||
}
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.args...)
|
||||
}
|
||||
_, _ = query.WriteString(`
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING request_id, api_key_id, id, created_at
|
||||
`)
|
||||
return query.String(), args
|
||||
}
|
||||
|
||||
func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) {
|
||||
var query strings.Builder
|
||||
_, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `)
|
||||
args := make([]any, 0, len(keys)*2)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(" OR ")
|
||||
}
|
||||
prepared := preparedByKey[key]
|
||||
apiKeyID := prepared.args[1]
|
||||
_, _ = query.WriteString("(request_id = $")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
_, _ = query.WriteString(" AND api_key_id = $")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos + 1))
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.requestID, apiKeyID)
|
||||
argPos += 2
|
||||
}
|
||||
|
||||
rows, err := db.QueryContext(ctx, query.String(), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
stateMap := make(map[string]usageLogBatchState, len(keys))
|
||||
for rows.Next() {
|
||||
var (
|
||||
requestID string
|
||||
apiKeyID int64
|
||||
id int64
|
||||
createdAt time.Time
|
||||
)
|
||||
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{
|
||||
ID: id,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stateMap, nil
|
||||
}
|
||||
|
||||
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
createdAt := log.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now()
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(log.RequestID)
|
||||
log.RequestID = requestID
|
||||
|
||||
rateMultiplier := log.RateMultiplier
|
||||
log.SyncRequestTypeAndLegacyFields()
|
||||
requestType := int16(log.RequestType)
|
||||
|
||||
groupID := nullInt64(log.GroupID)
|
||||
subscriptionID := nullInt64(log.SubscriptionID)
|
||||
duration := nullInt(log.DurationMs)
|
||||
@@ -167,58 +537,64 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
requestIDArg = requestID
|
||||
}
|
||||
|
||||
args := []any{
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
requestIDArg,
|
||||
log.Model,
|
||||
groupID,
|
||||
subscriptionID,
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
rateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
requestType,
|
||||
log.Stream,
|
||||
log.OpenAIWSMode,
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
ipAddress,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
return usageLogInsertPrepared{
|
||||
createdAt: createdAt,
|
||||
requestID: requestID,
|
||||
rateMultiplier: rateMultiplier,
|
||||
requestType: requestType,
|
||||
args: []any{
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
requestIDArg,
|
||||
log.Model,
|
||||
groupID,
|
||||
subscriptionID,
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
rateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
requestType,
|
||||
log.Stream,
|
||||
log.OpenAIWSMode,
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
ipAddress,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
},
|
||||
}
|
||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
|
||||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||||
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||||
return false, err
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return false, nil
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
func usageLogBatchKey(requestID string, apiKeyID int64) string {
|
||||
return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10)
|
||||
}
|
||||
|
||||
func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateResult) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- res:
|
||||
default:
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||
|
||||
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -84,6 +87,102 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
s.Require().NotZero(log.ID)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
|
||||
|
||||
const total = 16
|
||||
results := make([]bool, total)
|
||||
errs := make([]error, total)
|
||||
logs := make([]*service.UsageLog, total)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(total)
|
||||
for i := 0; i < total; i++ {
|
||||
i := i
|
||||
logs[i] = &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i], errs[i] = repo.Create(ctx, logs[i])
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
require.NoError(t, errs[i])
|
||||
require.True(t, results[i])
|
||||
require.NotZero(t, logs[i].ID)
|
||||
}
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
|
||||
require.Equal(t, total, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
inserted1, err1 := repo.Create(ctx, log1)
|
||||
inserted2, err2 := repo.Create(ctx, log2)
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
require.True(t, inserted1)
|
||||
require.False(t, inserted2)
|
||||
require.Equal(t, log1.ID, log2.ID)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
|
||||
Reference in New Issue
Block a user