From 7134266acfae3e3bfce4d983b6258afc6624526c Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Tue, 17 Mar 2026 19:25:52 +0800 Subject: [PATCH] feat(dashboard): add model source dimension to stats queries Support querying model statistics by 'requested', 'upstream', or 'mapping' dimension. Add resolveModelDimensionExpression for safe SQL expression generation, IsValidModelSource whitelist validator, and NormalizeModelSource fallback. Repository persists and scans upstream_model in all insert/select paths. --- .../pkg/usagestats/usage_log_types.go | 23 ++++++ backend/internal/repository/usage_log_repo.go | 77 ++++++++++++++----- backend/internal/service/dashboard_service.go | 21 +++++ 3 files changed, 102 insertions(+), 19 deletions(-) diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index f42a746f..de3ad378 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -3,6 +3,28 @@ package usagestats import "time" +const ( + ModelSourceRequested = "requested" + ModelSourceUpstream = "upstream" + ModelSourceMapping = "mapping" +) + +func IsValidModelSource(source string) bool { + switch source { + case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping: + return true + default: + return false + } +} + +func NormalizeModelSource(source string) string { + if IsValidModelSource(source) { + return source + } + return ModelSourceRequested +} + // DashboardStats 仪表盘统计 type DashboardStats struct { // 用户统计 @@ -143,6 +165,7 @@ type UserBreakdownItem struct { type UserBreakdownDimension struct { GroupID int64 // filter by group_id (>0 to enable) Model string // filter by model name (non-empty to enable) + ModelType string // "requested", "upstream", or "mapping" Endpoint string // filter by endpoint value (non-empty to enable) EndpointType string // "inbound", "upstream", or "path" } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index dcdaeaee..61a54267 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" var usageLogInsertArgTypes = [...]string{ "bigint", @@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{ "bigint", "text", "text", + "text", "bigint", "bigint", "integer", @@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*38) + args := make([]any, 0, len(keys)*39) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*38) + args := make([]any, 0, len(preparedList)*39) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + upstreamModel := nullString(log.UpstreamModel) var requestIDArg any if requestID != "" { @@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.AccountID, requestIDArg, log.Model, + upstreamModel, groupID, subscriptionID, log.InputTokens, @@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st // GetModelStatsWithFilters returns model statistics with optional filters func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested) +} + +// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension. +// source: requested | upstream | mapping. +func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source) +} + +func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } + modelExpr := resolveModelDimensionExpression(source) query := fmt.Sprintf(` SELECT - model, + %s as model, COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, @@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start %s FROM usage_logs WHERE created_at >= $1 AND created_at < $2 - `, actualCostExpr) + `, modelExpr, actualCostExpr) args := []any{startTime, endTime} if userID > 0 { @@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) } - query += " GROUP BY model ORDER BY total_tokens DESC" + query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr) rows, err := r.sql.QueryContext(ctx, query, args...) if err != nil { @@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim args = append(args, dim.GroupID) } if dim.Model != "" { - query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1) + query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1) args = append(args, dim.Model) } if dim.Endpoint != "" { @@ -3067,6 +3089,18 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim return results, nil } +// resolveModelDimensionExpression maps model source type to a safe SQL expression. +func resolveModelDimensionExpression(modelType string) string { + switch usagestats.NormalizeModelSource(modelType) { + case usagestats.ModelSourceUpstream: + return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)" + case usagestats.ModelSourceMapping: + return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))" + default: + return "model" + } +} + // resolveEndpointColumn maps endpoint type to the corresponding DB column name. func resolveEndpointColumn(endpointType string) string { switch endpointType { @@ -3819,6 +3853,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e accountID int64 requestID sql.NullString model string + upstreamModel sql.NullString groupID sql.NullInt64 subscriptionID sql.NullInt64 inputTokens int @@ -3861,6 +3896,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &accountID, &requestID, &model, + &upstreamModel, &groupID, &subscriptionID, &inputTokens, @@ -3973,6 +4009,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if upstreamEndpoint.Valid { log.UpstreamEndpoint = &upstreamEndpoint.String } + if upstreamModel.Valid { + log.UpstreamModel = &upstreamModel.String + } return log, nil } diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index ad29990f..1c960fdf 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi return stats, nil } +func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) { + normalizedSource := usagestats.NormalizeModelSource(modelSource) + if normalizedSource == usagestats.ModelSourceRequested { + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + } + + type modelStatsBySourceRepo interface { + GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error) + } + + if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok { + stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource) + if err != nil { + return nil, fmt.Errorf("get model stats with filters by source: %w", err) + } + return stats, nil + } + + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) +} + func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil {