mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-18 22:04:45 +08:00
Batch usage log writes in repository
This commit is contained in:
@@ -6,7 +6,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
@@ -43,6 +45,39 @@ func safeDateFormat(granularity string) string {
|
|||||||
type usageLogRepository struct {
|
type usageLogRepository struct {
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
sql sqlExecutor
|
sql sqlExecutor
|
||||||
|
db *sql.DB
|
||||||
|
|
||||||
|
createBatchOnce sync.Once
|
||||||
|
createBatchCh chan usageLogCreateRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
usageLogCreateBatchMaxSize = 64
|
||||||
|
usageLogCreateBatchWindow = 3 * time.Millisecond
|
||||||
|
usageLogCreateBatchQueueCap = 4096
|
||||||
|
)
|
||||||
|
|
||||||
|
type usageLogCreateRequest struct {
|
||||||
|
log *service.UsageLog
|
||||||
|
resultCh chan usageLogCreateResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type usageLogCreateResult struct {
|
||||||
|
inserted bool
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type usageLogInsertPrepared struct {
|
||||||
|
createdAt time.Time
|
||||||
|
requestID string
|
||||||
|
rateMultiplier float64
|
||||||
|
requestType int16
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
|
||||||
|
type usageLogBatchState struct {
|
||||||
|
ID int64
|
||||||
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
|
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
|
||||||
@@ -51,7 +86,11 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
|
|||||||
|
|
||||||
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
||||||
// 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。
|
// 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。
|
||||||
return &usageLogRepository{client: client, sql: sqlq}
|
repo := &usageLogRepository{client: client, sql: sqlq}
|
||||||
|
if db, ok := sqlq.(*sql.DB); ok {
|
||||||
|
repo.db = db
|
||||||
|
}
|
||||||
|
return repo
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
|
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
|
||||||
@@ -82,24 +121,25 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
|
|
||||||
// 无事务时回退到默认的 *sql.DB 执行器。
|
|
||||||
sqlq := r.sql
|
|
||||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||||
sqlq = tx.Client()
|
return r.createSingle(ctx, tx.Client(), log)
|
||||||
}
|
}
|
||||||
|
if r.db == nil {
|
||||||
createdAt := log.CreatedAt
|
return r.createSingle(ctx, r.sql, log)
|
||||||
if createdAt.IsZero() {
|
|
||||||
createdAt = time.Now()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
requestID := strings.TrimSpace(log.RequestID)
|
requestID := strings.TrimSpace(log.RequestID)
|
||||||
|
if requestID == "" {
|
||||||
|
return r.createSingle(ctx, r.sql, log)
|
||||||
|
}
|
||||||
log.RequestID = requestID
|
log.RequestID = requestID
|
||||||
|
return r.createBatched(ctx, log)
|
||||||
|
}
|
||||||
|
|
||||||
rateMultiplier := log.RateMultiplier
|
func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
|
||||||
log.SyncRequestTypeAndLegacyFields()
|
prepared := prepareUsageLogInsert(log)
|
||||||
requestType := int16(log.RequestType)
|
if sqlq == nil {
|
||||||
|
sqlq = r.sql
|
||||||
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO usage_logs (
|
INSERT INTO usage_logs (
|
||||||
@@ -151,6 +191,336 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
`
|
`
|
||||||
|
|
||||||
|
if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" {
|
||||||
|
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||||||
|
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
log.RateMultiplier = prepared.rateMultiplier
|
||||||
|
return false, nil
|
||||||
|
} else {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.RateMultiplier = prepared.rateMultiplier
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||||
|
if r.db == nil {
|
||||||
|
return r.createSingle(ctx, r.sql, log)
|
||||||
|
}
|
||||||
|
r.ensureCreateBatcher()
|
||||||
|
if r.createBatchCh == nil {
|
||||||
|
return r.createSingle(ctx, r.sql, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := usageLogCreateRequest{
|
||||||
|
log: log,
|
||||||
|
resultCh: make(chan usageLogCreateResult, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case r.createBatchCh <- req:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, ctx.Err()
|
||||||
|
default:
|
||||||
|
return r.createSingle(ctx, r.sql, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case res := <-req.resultCh:
|
||||||
|
return res.inserted, res.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) ensureCreateBatcher() {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.createBatchOnce.Do(func() {
|
||||||
|
r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap)
|
||||||
|
go r.runCreateBatcher(r.db)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
|
||||||
|
for {
|
||||||
|
first, ok := <-r.createBatchCh
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize)
|
||||||
|
batch = append(batch, first)
|
||||||
|
|
||||||
|
timer := time.NewTimer(usageLogCreateBatchWindow)
|
||||||
|
batchLoop:
|
||||||
|
for len(batch) < usageLogCreateBatchMaxSize {
|
||||||
|
select {
|
||||||
|
case req, ok := <-r.createBatchCh:
|
||||||
|
if !ok {
|
||||||
|
break batchLoop
|
||||||
|
}
|
||||||
|
batch = append(batch, req)
|
||||||
|
case <-timer.C:
|
||||||
|
break batchLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.flushCreateBatch(db, batch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
|
||||||
|
if len(batch) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
uniqueOrder := make([]string, 0, len(batch))
|
||||||
|
preparedByKey := make(map[string]usageLogInsertPrepared, len(batch))
|
||||||
|
requestsByKey := make(map[string][]usageLogCreateRequest, len(batch))
|
||||||
|
fallback := make([]usageLogCreateRequest, 0)
|
||||||
|
|
||||||
|
for _, req := range batch {
|
||||||
|
if req.log == nil {
|
||||||
|
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prepared := prepareUsageLogInsert(req.log)
|
||||||
|
if prepared.requestID == "" {
|
||||||
|
fallback = append(fallback, req)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID)
|
||||||
|
if _, exists := requestsByKey[key]; !exists {
|
||||||
|
uniqueOrder = append(uniqueOrder, key)
|
||||||
|
preparedByKey[key] = prepared
|
||||||
|
}
|
||||||
|
requestsByKey[key] = append(requestsByKey[key], req)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(uniqueOrder) > 0 {
|
||||||
|
insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
|
||||||
|
if err != nil {
|
||||||
|
for _, key := range uniqueOrder {
|
||||||
|
fallback = append(fallback, requestsByKey[key]...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, key := range uniqueOrder {
|
||||||
|
reqs := requestsByKey[key]
|
||||||
|
state, ok := stateMap[key]
|
||||||
|
if !ok {
|
||||||
|
for _, req := range reqs {
|
||||||
|
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
|
||||||
|
inserted: false,
|
||||||
|
err: fmt.Errorf("usage log batch state missing for key=%s", key),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for idx, req := range reqs {
|
||||||
|
req.log.ID = state.ID
|
||||||
|
req.log.CreatedAt = state.CreatedAt
|
||||||
|
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
|
||||||
|
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
|
||||||
|
inserted: idx == 0 && insertedMap[key],
|
||||||
|
err: nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fallback) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
for _, req := range fallback {
|
||||||
|
inserted, err := r.createSingle(fallbackCtx, db, req.log)
|
||||||
|
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return map[string]bool{}, map[string]usageLogBatchState{}, nil
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
|
||||||
|
rows, err := db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
insertedMap := make(map[string]bool, len(keys))
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
requestID string
|
||||||
|
apiKeyID int64
|
||||||
|
id int64
|
||||||
|
createdAt time.Time
|
||||||
|
)
|
||||||
|
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return insertedMap, stateMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
|
||||||
|
var query strings.Builder
|
||||||
|
_, _ = query.WriteString(`
|
||||||
|
INSERT INTO usage_logs (
|
||||||
|
user_id,
|
||||||
|
api_key_id,
|
||||||
|
account_id,
|
||||||
|
request_id,
|
||||||
|
model,
|
||||||
|
group_id,
|
||||||
|
subscription_id,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
cache_creation_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
|
cache_creation_5m_tokens,
|
||||||
|
cache_creation_1h_tokens,
|
||||||
|
input_cost,
|
||||||
|
output_cost,
|
||||||
|
cache_creation_cost,
|
||||||
|
cache_read_cost,
|
||||||
|
total_cost,
|
||||||
|
actual_cost,
|
||||||
|
rate_multiplier,
|
||||||
|
account_rate_multiplier,
|
||||||
|
billing_type,
|
||||||
|
request_type,
|
||||||
|
stream,
|
||||||
|
openai_ws_mode,
|
||||||
|
duration_ms,
|
||||||
|
first_token_ms,
|
||||||
|
user_agent,
|
||||||
|
ip_address,
|
||||||
|
image_count,
|
||||||
|
image_size,
|
||||||
|
media_type,
|
||||||
|
service_tier,
|
||||||
|
reasoning_effort,
|
||||||
|
cache_ttl_overridden,
|
||||||
|
created_at
|
||||||
|
) VALUES `)
|
||||||
|
|
||||||
|
args := make([]any, 0, len(keys)*36)
|
||||||
|
argPos := 1
|
||||||
|
for idx, key := range keys {
|
||||||
|
if idx > 0 {
|
||||||
|
_, _ = query.WriteString(",")
|
||||||
|
}
|
||||||
|
_, _ = query.WriteString("(")
|
||||||
|
prepared := preparedByKey[key]
|
||||||
|
for i := 0; i < len(prepared.args); i++ {
|
||||||
|
if i > 0 {
|
||||||
|
_, _ = query.WriteString(",")
|
||||||
|
}
|
||||||
|
_, _ = query.WriteString("$")
|
||||||
|
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||||
|
argPos++
|
||||||
|
}
|
||||||
|
_, _ = query.WriteString(")")
|
||||||
|
args = append(args, prepared.args...)
|
||||||
|
}
|
||||||
|
_, _ = query.WriteString(`
|
||||||
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
|
RETURNING request_id, api_key_id, id, created_at
|
||||||
|
`)
|
||||||
|
return query.String(), args
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) {
|
||||||
|
var query strings.Builder
|
||||||
|
_, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `)
|
||||||
|
args := make([]any, 0, len(keys)*2)
|
||||||
|
argPos := 1
|
||||||
|
for idx, key := range keys {
|
||||||
|
if idx > 0 {
|
||||||
|
_, _ = query.WriteString(" OR ")
|
||||||
|
}
|
||||||
|
prepared := preparedByKey[key]
|
||||||
|
apiKeyID := prepared.args[1]
|
||||||
|
_, _ = query.WriteString("(request_id = $")
|
||||||
|
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||||
|
_, _ = query.WriteString(" AND api_key_id = $")
|
||||||
|
_, _ = query.WriteString(strconv.Itoa(argPos + 1))
|
||||||
|
_, _ = query.WriteString(")")
|
||||||
|
args = append(args, prepared.requestID, apiKeyID)
|
||||||
|
argPos += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.QueryContext(ctx, query.String(), args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
stateMap := make(map[string]usageLogBatchState, len(keys))
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
requestID string
|
||||||
|
apiKeyID int64
|
||||||
|
id int64
|
||||||
|
createdAt time.Time
|
||||||
|
)
|
||||||
|
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{
|
||||||
|
ID: id,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return stateMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||||
|
createdAt := log.CreatedAt
|
||||||
|
if createdAt.IsZero() {
|
||||||
|
createdAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := strings.TrimSpace(log.RequestID)
|
||||||
|
log.RequestID = requestID
|
||||||
|
|
||||||
|
rateMultiplier := log.RateMultiplier
|
||||||
|
log.SyncRequestTypeAndLegacyFields()
|
||||||
|
requestType := int16(log.RequestType)
|
||||||
|
|
||||||
groupID := nullInt64(log.GroupID)
|
groupID := nullInt64(log.GroupID)
|
||||||
subscriptionID := nullInt64(log.SubscriptionID)
|
subscriptionID := nullInt64(log.SubscriptionID)
|
||||||
duration := nullInt(log.DurationMs)
|
duration := nullInt(log.DurationMs)
|
||||||
@@ -167,58 +537,64 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
requestIDArg = requestID
|
requestIDArg = requestID
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []any{
|
return usageLogInsertPrepared{
|
||||||
log.UserID,
|
createdAt: createdAt,
|
||||||
log.APIKeyID,
|
requestID: requestID,
|
||||||
log.AccountID,
|
rateMultiplier: rateMultiplier,
|
||||||
requestIDArg,
|
requestType: requestType,
|
||||||
log.Model,
|
args: []any{
|
||||||
groupID,
|
log.UserID,
|
||||||
subscriptionID,
|
log.APIKeyID,
|
||||||
log.InputTokens,
|
log.AccountID,
|
||||||
log.OutputTokens,
|
requestIDArg,
|
||||||
log.CacheCreationTokens,
|
log.Model,
|
||||||
log.CacheReadTokens,
|
groupID,
|
||||||
log.CacheCreation5mTokens,
|
subscriptionID,
|
||||||
log.CacheCreation1hTokens,
|
log.InputTokens,
|
||||||
log.InputCost,
|
log.OutputTokens,
|
||||||
log.OutputCost,
|
log.CacheCreationTokens,
|
||||||
log.CacheCreationCost,
|
log.CacheReadTokens,
|
||||||
log.CacheReadCost,
|
log.CacheCreation5mTokens,
|
||||||
log.TotalCost,
|
log.CacheCreation1hTokens,
|
||||||
log.ActualCost,
|
log.InputCost,
|
||||||
rateMultiplier,
|
log.OutputCost,
|
||||||
log.AccountRateMultiplier,
|
log.CacheCreationCost,
|
||||||
log.BillingType,
|
log.CacheReadCost,
|
||||||
requestType,
|
log.TotalCost,
|
||||||
log.Stream,
|
log.ActualCost,
|
||||||
log.OpenAIWSMode,
|
rateMultiplier,
|
||||||
duration,
|
log.AccountRateMultiplier,
|
||||||
firstToken,
|
log.BillingType,
|
||||||
userAgent,
|
requestType,
|
||||||
ipAddress,
|
log.Stream,
|
||||||
log.ImageCount,
|
log.OpenAIWSMode,
|
||||||
imageSize,
|
duration,
|
||||||
mediaType,
|
firstToken,
|
||||||
serviceTier,
|
userAgent,
|
||||||
reasoningEffort,
|
ipAddress,
|
||||||
log.CacheTTLOverridden,
|
log.ImageCount,
|
||||||
createdAt,
|
imageSize,
|
||||||
|
mediaType,
|
||||||
|
serviceTier,
|
||||||
|
reasoningEffort,
|
||||||
|
log.CacheTTLOverridden,
|
||||||
|
createdAt,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
}
|
||||||
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
|
|
||||||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
func usageLogBatchKey(requestID string, apiKeyID int64) string {
|
||||||
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10)
|
||||||
return false, err
|
}
|
||||||
}
|
|
||||||
log.RateMultiplier = rateMultiplier
|
func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateResult) {
|
||||||
return false, nil
|
if ch == nil {
|
||||||
} else {
|
return
|
||||||
return false, err
|
}
|
||||||
}
|
select {
|
||||||
|
case ch <- res:
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
log.RateMultiplier = rateMultiplier
|
|
||||||
return true, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -14,6 +16,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,6 +87,102 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
|||||||
s.Require().NotZero(log.ID)
|
s.Require().NotZero(log.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
|
||||||
|
|
||||||
|
const total = 16
|
||||||
|
results := make([]bool, total)
|
||||||
|
errs := make([]error, total)
|
||||||
|
logs := make([]*service.UsageLog, total)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(total)
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
i := i
|
||||||
|
logs[i] = &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10 + i,
|
||||||
|
OutputTokens: 20 + i,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
results[i], errs[i] = repo.Create(ctx, logs[i])
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
require.NoError(t, errs[i])
|
||||||
|
require.True(t, results[i])
|
||||||
|
require.NotZero(t, logs[i].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, total, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
log1 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
log2 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
inserted1, err1 := repo.Create(ctx, log1)
|
||||||
|
inserted2, err2 := repo.Create(ctx, log2)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
require.NoError(t, err2)
|
||||||
|
require.True(t, inserted1)
|
||||||
|
require.False(t, inserted2)
|
||||||
|
require.Equal(t, log1.ID, log2.ID)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||||
|
|||||||
Reference in New Issue
Block a user