feat(dashboard): add model source dimension to stats queries

Support querying model statistics by 'requested', 'upstream', or 'mapping' dimension. Add resolveModelDimensionExpression for safe SQL expression generation, IsValidModelSource whitelist validator, and NormalizeModelSource fallback. Repository persists and scans upstream_model in all insert/select paths.
This commit is contained in:
Ethan0x0000
2026-03-17 19:25:52 +08:00
parent 2e4ac88ad9
commit 7134266acf
3 changed files with 102 additions and 19 deletions

View File

@@ -3,6 +3,28 @@ package usagestats
import "time"
const (
ModelSourceRequested = "requested"
ModelSourceUpstream = "upstream"
ModelSourceMapping = "mapping"
)
func IsValidModelSource(source string) bool {
switch source {
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
return true
default:
return false
}
}
func NormalizeModelSource(source string) string {
if IsValidModelSource(source) {
return source
}
return ModelSourceRequested
}
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
@@ -143,6 +165,7 @@ type UserBreakdownItem struct {
type UserBreakdownDimension struct {
GroupID int64 // filter by group_id (>0 to enable)
Model string // filter by model name (non-empty to enable)
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path"
}

View File

@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{
"bigint",
@@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{
"bigint",
"text",
"text",
"text",
"bigint",
"bigint",
"integer",
@@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*38)
args := make([]any, 0, len(keys)*39)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*38)
args := make([]any, 0, len(preparedList)*39)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any
if requestID != "" {
@@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID,
requestIDArg,
log.Model,
upstreamModel,
groupID,
subscriptionID,
log.InputTokens,
@@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested)
}
// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension.
// source: requested | upstream | mapping.
func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source)
}
func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
modelExpr := resolveModelDimensionExpression(source)
query := fmt.Sprintf(`
SELECT
model,
%s as model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
@@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr)
`, modelExpr, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY model ORDER BY total_tokens DESC"
query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr)
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
@@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
args = append(args, dim.GroupID)
}
if dim.Model != "" {
query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1)
query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1)
args = append(args, dim.Model)
}
if dim.Endpoint != "" {
@@ -3067,6 +3089,18 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
return results, nil
}
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
func resolveModelDimensionExpression(modelType string) string {
switch usagestats.NormalizeModelSource(modelType) {
case usagestats.ModelSourceUpstream:
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"
case usagestats.ModelSourceMapping:
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"
default:
return "model"
}
}
// resolveEndpointColumn maps endpoint type to the corresponding DB column name.
func resolveEndpointColumn(endpointType string) string {
switch endpointType {
@@ -3819,6 +3853,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64
requestID sql.NullString
model string
upstreamModel sql.NullString
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
@@ -3861,6 +3896,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID,
&requestID,
&model,
&upstreamModel,
&groupID,
&subscriptionID,
&inputTokens,
@@ -3973,6 +4009,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamEndpoint.Valid {
log.UpstreamEndpoint = &upstreamEndpoint.String
}
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
return log, nil
}

View File

@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) {
normalizedSource := usagestats.NormalizeModelSource(modelSource)
if normalizedSource == usagestats.ModelSourceRequested {
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
type modelStatsBySourceRepo interface {
GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error)
}
if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok {
stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource)
if err != nil {
return nil, fmt.Errorf("get model stats with filters by source: %w", err)
}
return stats, nil
}
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {