Merge pull request #1075 from touwaeriol/feat/dashboard-user-breakdown

feat(dashboard): add per-user drill-down for distribution charts
This commit is contained in:
Wesley Liddick
2026-03-17 09:25:43 +08:00
committed by GitHub
18 changed files with 695 additions and 72 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -604,3 +605,41 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
// GET /api/v1/admin/dashboard/user-breakdown
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
dim := usagestats.UserBreakdownDimension{}
if v := c.Query("group_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.GroupID = id
}
}
dim.Model = c.Query("model")
dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
limit := 50
if v := c.Query("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
limit = n
}
}
stats, err := h.dashboardService.GetUserBreakdownStats(
c.Request.Context(), startTime, endTime, dim, limit,
)
if err != nil {
response.Error(c, 500, "Failed to get user breakdown stats")
return
}
response.Success(c, gin.H{
"users": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}

View File

@@ -0,0 +1,203 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// --- mock repo ---
type userBreakdownRepoCapture struct {
service.UsageLogRepository
capturedDim usagestats.UserBreakdownDimension
capturedLimit int
result []usagestats.UserBreakdownItem
}
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
_ context.Context, _, _ time.Time,
dim usagestats.UserBreakdownDimension, limit int,
) ([]usagestats.UserBreakdownItem, error) {
r.capturedDim = dim
r.capturedLimit = limit
if r.result != nil {
return r.result, nil
}
return []usagestats.UserBreakdownItem{}, nil
}
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
svc := service.NewDashboardService(repo, nil, nil, nil)
h := NewDashboardHandler(svc, nil)
router := gin.New()
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
return router
}
// --- tests ---
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, int64(42), repo.capturedDim.GroupID)
require.Empty(t, repo.capturedDim.Model)
require.Empty(t, repo.capturedDim.Endpoint)
require.Equal(t, 50, repo.capturedLimit) // default limit
}
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
require.Equal(t, int64(0), repo.capturedDim.GroupID)
}
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
}
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
}
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, 100, repo.capturedLimit)
}
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
// limit > 200 should fall back to default 50
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, 50, repo.capturedLimit)
}
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
repo := &userBreakdownRepoCapture{
result: []usagestats.UserBreakdownItem{
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
},
}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Users []usagestats.UserBreakdownItem `json:"users"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
} `json:"data"`
}
err := json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, err)
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Users, 2)
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
require.Equal(t, "2026-03-01", resp.Data.StartDate)
require.Equal(t, "2026-03-16", resp.Data.EndDate)
}
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp struct {
Data struct {
Users []usagestats.UserBreakdownItem `json:"users"`
} `json:"data"`
}
err := json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, err)
require.Empty(t, resp.Data.Users)
}
func TestGetUserBreakdown_NoFilters(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, int64(0), repo.capturedDim.GroupID)
require.Empty(t, repo.capturedDim.Model)
require.Empty(t, repo.capturedDim.Endpoint)
}

View File

@@ -345,6 +345,9 @@ func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Conte
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}

View File

@@ -129,6 +129,24 @@ type UserSpendingRankingResponse struct {
TotalTokens int64 `json:"total_tokens"`
}
// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint).
type UserBreakdownItem struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
type UserBreakdownDimension struct {
GroupID int64 // filter by group_id (>0 to enable)
Model string // filter by model name (non-empty to enable)
Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path"
}
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
Date string `json:"date"`

View File

@@ -3000,6 +3000,85 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
return results, nil
}
// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension.
func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) {
query := `
SELECT
COALESCE(ul.user_id, 0) as user_id,
COALESCE(u.email, '') as email,
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
FROM usage_logs ul
LEFT JOIN users u ON u.id = ul.user_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
`
args := []any{startTime, endTime}
if dim.GroupID > 0 {
query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
args = append(args, dim.GroupID)
}
if dim.Model != "" {
query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1)
args = append(args, dim.Model)
}
if dim.Endpoint != "" {
col := resolveEndpointColumn(dim.EndpointType)
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
}
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]usagestats.UserBreakdownItem, 0)
for rows.Next() {
var row usagestats.UserBreakdownItem
if err := rows.Scan(
&row.UserID,
&row.Email,
&row.Requests,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// resolveEndpointColumn maps endpoint type to the corresponding DB column name.
func resolveEndpointColumn(endpointType string) string {
switch endpointType {
case "upstream":
return "ul.upstream_endpoint"
case "path":
return "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"
default:
return "ul.inbound_endpoint"
}
}
// GetGlobalStats gets usage statistics for all users within a time range
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
query := `

View File

@@ -0,0 +1,29 @@
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestResolveEndpointColumn(t *testing.T) {
tests := []struct {
endpointType string
want string
}{
{"inbound", "ul.inbound_endpoint"},
{"upstream", "ul.upstream_endpoint"},
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
{"", "ul.inbound_endpoint"}, // default
{"unknown", "ul.inbound_endpoint"}, // fallback
}
for _, tc := range tests {
t.Run(tc.endpointType, func(t *testing.T) {
got := resolveEndpointColumn(tc.endpointType)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -1637,6 +1637,10 @@ func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}

View File

@@ -198,6 +198,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
dashboard.GET("/user-breakdown", h.Admin.Dashboard.GetUserBreakdown)
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
}
}

View File

@@ -48,6 +48,7 @@ type UsageLogRepository interface {
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)

View File

@@ -335,6 +335,14 @@ func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime
return ranking, nil
}
func (s *DashboardService) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
stats, err := s.usageRepo.GetUserBreakdownStats(ctx, startTime, endTime, dim, limit)
if err != nil {
return nil, fmt.Errorf("get user breakdown stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {