diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index f42a746f..de3ad378 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -3,6 +3,28 @@ package usagestats import "time" +const ( + ModelSourceRequested = "requested" + ModelSourceUpstream = "upstream" + ModelSourceMapping = "mapping" +) + +func IsValidModelSource(source string) bool { + switch source { + case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping: + return true + default: + return false + } +} + +func NormalizeModelSource(source string) string { + if IsValidModelSource(source) { + return source + } + return ModelSourceRequested +} + // DashboardStats 仪表盘统计 type DashboardStats struct { // 用户统计 @@ -143,6 +165,7 @@ type UserBreakdownItem struct { type UserBreakdownDimension struct { GroupID int64 // filter by group_id (>0 to enable) Model string // filter by model name (non-empty to enable) + ModelType string // "requested", "upstream", or "mapping" Endpoint string // filter by endpoint value (non-empty to enable) EndpointType string // "inbound", "upstream", or "path" } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index dcdaeaee..61a54267 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" var usageLogInsertArgTypes = [...]string{ "bigint", @@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{ "bigint", "text", "text", + "text", "bigint", "bigint", "integer", @@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*38) + args := make([]any, 0, len(keys)*39) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*38) + args := make([]any, 0, len(preparedList)*39) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + upstreamModel := nullString(log.UpstreamModel) var requestIDArg any if requestID != "" { @@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.AccountID, requestIDArg, log.Model, + upstreamModel, groupID, subscriptionID, log.InputTokens, @@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st // GetModelStatsWithFilters returns model statistics with optional filters func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested) +} + +// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension. +// source: requested | upstream | mapping. +func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source) +} + +func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } + modelExpr := resolveModelDimensionExpression(source) query := fmt.Sprintf(` SELECT - model, + %s as model, COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, @@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start %s FROM usage_logs WHERE created_at >= $1 AND created_at < $2 - `, actualCostExpr) + `, modelExpr, actualCostExpr) args := []any{startTime, endTime} if userID > 0 { @@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) } - query += " GROUP BY model ORDER BY total_tokens DESC" + query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr) rows, err := r.sql.QueryContext(ctx, query, args...) if err != nil { @@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim args = append(args, dim.GroupID) } if dim.Model != "" { - query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1) + query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1) args = append(args, dim.Model) } if dim.Endpoint != "" { @@ -3067,6 +3089,18 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim return results, nil } +// resolveModelDimensionExpression maps model source type to a safe SQL expression. +func resolveModelDimensionExpression(modelType string) string { + switch usagestats.NormalizeModelSource(modelType) { + case usagestats.ModelSourceUpstream: + return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)" + case usagestats.ModelSourceMapping: + return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))" + default: + return "model" + } +} + // resolveEndpointColumn maps endpoint type to the corresponding DB column name. func resolveEndpointColumn(endpointType string) string { switch endpointType { @@ -3819,6 +3853,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e accountID int64 requestID sql.NullString model string + upstreamModel sql.NullString groupID sql.NullInt64 subscriptionID sql.NullInt64 inputTokens int @@ -3861,6 +3896,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &accountID, &requestID, &model, + &upstreamModel, &groupID, &subscriptionID, &inputTokens, @@ -3973,6 +4009,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if upstreamEndpoint.Valid { log.UpstreamEndpoint = &upstreamEndpoint.String } + if upstreamModel.Valid { + log.UpstreamModel = &upstreamModel.String + } return log, nil } diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index ad29990f..1c960fdf 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi return stats, nil } +func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) { + normalizedSource := usagestats.NormalizeModelSource(modelSource) + if normalizedSource == usagestats.ModelSourceRequested { + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + } + + type modelStatsBySourceRepo interface { + GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error) + } + + if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok { + stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource) + if err != nil { + return nil, fmt.Errorf("get model stats with filters by source: %w", err) + } + return stats, nil + } + + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) +} + func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil {