From eeff451bc58717b994901c48207be19daba39fdc Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Tue, 17 Mar 2026 19:26:30 +0800 Subject: [PATCH] test(backend): add tests for upstream model tracking and model source filtering Cover IsValidModelSource/NormalizeModelSource, resolveModelDimensionExpression SQL expressions, invalid model_source 400 responses on both GetModelStats and GetUserBreakdown, upstream_model in scan/insert SQL mock expectations, and updated passthrough/billing test signatures. --- .../dashboard_handler_request_type_test.go | 22 +++++++++ .../dashboard_handler_user_breakdown_test.go | 26 ++++++++++ .../pkg/usagestats/usage_log_types_test.go | 47 +++++++++++++++++++ .../usage_log_repo_breakdown_test.go | 25 +++++++++- .../usage_log_repo_request_type_test.go | 5 ++ ...teway_anthropic_apikey_passthrough_test.go | 8 ++-- .../openai_gateway_record_usage_test.go | 7 ++- 7 files changed, 132 insertions(+), 8 deletions(-) create mode 100644 backend/internal/pkg/usagestats/usage_log_types_test.go diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go index 9aec61d4..6056f725 100644 --- a/backend/internal/handler/admin/dashboard_handler_request_type_test.go +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } +func TestDashboardModelStatsInvalidModelSource(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsValidModelSource(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} + func TestDashboardUsersRankingLimitAndCache(t *testing.T) { dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) repo := &dashboardUsageRepoCapture{ diff --git a/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go b/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go index 2c1dbd59..b3a05111 100644 --- a/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go +++ b/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go @@ -73,9 +73,35 @@ func TestGetUserBreakdown_ModelFilter(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model) + require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType) require.Equal(t, int64(0), repo.capturedDim.GroupID) } +func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType) +} + +func TestGetUserBreakdown_InvalidModelSource(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code) +} + func TestGetUserBreakdown_EndpointFilter(t *testing.T) { repo := &userBreakdownRepoCapture{} router := newUserBreakdownRouter(repo) diff --git a/backend/internal/pkg/usagestats/usage_log_types_test.go b/backend/internal/pkg/usagestats/usage_log_types_test.go new file mode 100644 index 00000000..95cf6069 --- /dev/null +++ b/backend/internal/pkg/usagestats/usage_log_types_test.go @@ -0,0 +1,47 @@ +package usagestats + +import "testing" + +func TestIsValidModelSource(t *testing.T) { + tests := []struct { + name string + source string + want bool + }{ + {name: "requested", source: ModelSourceRequested, want: true}, + {name: "upstream", source: ModelSourceUpstream, want: true}, + {name: "mapping", source: ModelSourceMapping, want: true}, + {name: "invalid", source: "foobar", want: false}, + {name: "empty", source: "", want: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := IsValidModelSource(tc.source); got != tc.want { + t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want) + } + }) + } +} + +func TestNormalizeModelSource(t *testing.T) { + tests := []struct { + name string + source string + want string + }{ + {name: "requested", source: ModelSourceRequested, want: ModelSourceRequested}, + {name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream}, + {name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping}, + {name: "invalid falls back", source: "foobar", want: ModelSourceRequested}, + {name: "empty falls back", source: "", want: ModelSourceRequested}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := NormalizeModelSource(tc.source); got != tc.want { + t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want) + } + }) + } +} diff --git a/backend/internal/repository/usage_log_repo_breakdown_test.go b/backend/internal/repository/usage_log_repo_breakdown_test.go index ca63e0bc..5d908bfd 100644 --- a/backend/internal/repository/usage_log_repo_breakdown_test.go +++ b/backend/internal/repository/usage_log_repo_breakdown_test.go @@ -5,6 +5,7 @@ package repository import ( "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/stretchr/testify/require" ) @@ -16,8 +17,8 @@ func TestResolveEndpointColumn(t *testing.T) { {"inbound", "ul.inbound_endpoint"}, {"upstream", "ul.upstream_endpoint"}, {"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"}, - {"", "ul.inbound_endpoint"}, // default - {"unknown", "ul.inbound_endpoint"}, // fallback + {"", "ul.inbound_endpoint"}, // default + {"unknown", "ul.inbound_endpoint"}, // fallback } for _, tc := range tests { @@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) { }) } } + +func TestResolveModelDimensionExpression(t *testing.T) { + tests := []struct { + modelType string + want string + }{ + {usagestats.ModelSourceRequested, "model"}, + {usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"}, + {usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"}, + {"", "model"}, + {"invalid", "model"}, + } + + for _, tc := range tests { + t.Run(tc.modelType, func(t *testing.T) { + got := resolveModelDimensionExpression(tc.modelType) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 27ae4571..76827c31 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.AccountID, log.RequestID, log.Model, + sqlmock.AnyArg(), // upstream_model sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // subscription_id log.InputTokens, @@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { log.Model, sqlmock.AnyArg(), sqlmock.AnyArg(), + sqlmock.AnyArg(), log.InputTokens, log.OutputTokens, log.CacheCreationTokens, @@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(30), // account_id sql.NullString{Valid: true, String: "req-1"}, "gpt-5", // model + sql.NullString{}, // upstream_model sql.NullInt64{}, // group_id sql.NullInt64{}, // subscription_id 1, // input_tokens @@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(31), sql.NullString{Valid: true, String: "req-2"}, "gpt-5", + sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, @@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(32), sql.NullString{Valid: true, String: "req-3"}, "gpt-5.4", + sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 789cbab8..c534a9b7 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc rateLimitService: &RateLimitService{}, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 12, result.Usage.InputTokens) @@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp } svc := &GatewayService{} - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "requires apikey token") @@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest } account := newAnthropicAPIKeyAccountForTest() - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "upstream request failed") @@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo httpUpstream: upstream, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "empty response") diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index ada7d805..a35f9127 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) { require.Nil(t, extractOpenAIServiceTierFromBody(nil)) } -func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { +func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{} @@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te RequestID: "resp_billing_model_override", BillingModel: "gpt-5.1-codex", Model: "gpt-5.1", + UpstreamModel: "gpt-5.1-codex", ServiceTier: &serviceTier, ReasoningEffort: &reasoning, Usage: OpenAIUsage{ @@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te require.NoError(t, err) require.NotNil(t, usageRepo.lastLog) - require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel) require.NotNil(t, usageRepo.lastLog.ServiceTier) require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ReasoningEffort)