test(backend): add tests for upstream model tracking and model source filtering

Cover IsValidModelSource/NormalizeModelSource, resolveModelDimensionExpression SQL expressions, invalid model_source 400 responses on both GetModelStats and GetUserBreakdown, upstream_model in scan/insert SQL mock expectations, and updated passthrough/billing test signatures.
This commit is contained in:
Ethan0x0000
2026-03-17 19:26:30 +08:00
parent 56fcb20f94
commit eeff451bc5
7 changed files with 132 additions and 8 deletions

View File

@@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code) require.Equal(t, http.StatusBadRequest, rec.Code)
} }
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsValidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) { func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{ repo := &dashboardUsageRepoCapture{

View File

@@ -73,9 +73,35 @@ func TestGetUserBreakdown_ModelFilter(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model) require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
require.Equal(t, int64(0), repo.capturedDim.GroupID) require.Equal(t, int64(0), repo.capturedDim.GroupID)
} }
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
}
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestGetUserBreakdown_EndpointFilter(t *testing.T) { func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{} repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo) router := newUserBreakdownRouter(repo)

View File

@@ -0,0 +1,47 @@
package usagestats
import "testing"
func TestIsValidModelSource(t *testing.T) {
tests := []struct {
name string
source string
want bool
}{
{name: "requested", source: ModelSourceRequested, want: true},
{name: "upstream", source: ModelSourceUpstream, want: true},
{name: "mapping", source: ModelSourceMapping, want: true},
{name: "invalid", source: "foobar", want: false},
{name: "empty", source: "", want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := IsValidModelSource(tc.source); got != tc.want {
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
}
})
}
}
func TestNormalizeModelSource(t *testing.T) {
tests := []struct {
name string
source string
want string
}{
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
{name: "empty falls back", source: "", want: ModelSourceRequested},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := NormalizeModelSource(tc.source); got != tc.want {
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
}
})
}
}

View File

@@ -5,6 +5,7 @@ package repository
import ( import (
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -16,8 +17,8 @@ func TestResolveEndpointColumn(t *testing.T) {
{"inbound", "ul.inbound_endpoint"}, {"inbound", "ul.inbound_endpoint"},
{"upstream", "ul.upstream_endpoint"}, {"upstream", "ul.upstream_endpoint"},
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"}, {"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
{"", "ul.inbound_endpoint"}, // default {"", "ul.inbound_endpoint"}, // default
{"unknown", "ul.inbound_endpoint"}, // fallback {"unknown", "ul.inbound_endpoint"}, // fallback
} }
for _, tc := range tests { for _, tc := range tests {
@@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) {
}) })
} }
} }
func TestResolveModelDimensionExpression(t *testing.T) {
tests := []struct {
modelType string
want string
}{
{usagestats.ModelSourceRequested, "model"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
{"", "model"},
{"invalid", "model"},
}
for _, tc := range tests {
t.Run(tc.modelType, func(t *testing.T) {
got := resolveModelDimensionExpression(tc.modelType)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID, log.AccountID,
log.RequestID, log.RequestID,
log.Model, log.Model,
sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id sqlmock.AnyArg(), // subscription_id
log.InputTokens, log.InputTokens,
@@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.Model, log.Model,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.InputTokens, log.InputTokens,
log.OutputTokens, log.OutputTokens,
log.CacheCreationTokens, log.CacheCreationTokens,
@@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(30), // account_id int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"}, sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model "gpt-5", // model
sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id sql.NullInt64{}, // subscription_id
1, // input_tokens 1, // input_tokens
@@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31), int64(31),
sql.NullString{Valid: true, String: "req-2"}, sql.NullString{Valid: true, String: "req-2"},
"gpt-5", "gpt-5",
sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
@@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32), int64(32),
sql.NullString{Valid: true, String: "req-3"}, sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4", "gpt-5.4",
sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,

View File

@@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc
rateLimitService: &RateLimitService{}, rateLimitService: &RateLimitService{},
} }
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens) require.Equal(t, 12, result.Usage.InputTokens)
@@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp
} }
svc := &GatewayService{} svc := &GatewayService{}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "requires apikey token") require.Contains(t, err.Error(), "requires apikey token")
@@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest
} }
account := newAnthropicAPIKeyAccountForTest() account := newAnthropicAPIKeyAccountForTest()
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "upstream request failed") require.Contains(t, err.Error(), "upstream request failed")
@@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo
httpUpstream: upstream, httpUpstream: upstream,
} }
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "empty response") require.Contains(t, err.Error(), "empty response")

View File

@@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Nil(t, extractOpenAIServiceTierFromBody(nil)) require.Nil(t, extractOpenAIServiceTierFromBody(nil))
} }
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{}
@@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
RequestID: "resp_billing_model_override", RequestID: "resp_billing_model_override",
BillingModel: "gpt-5.1-codex", BillingModel: "gpt-5.1-codex",
Model: "gpt-5.1", Model: "gpt-5.1",
UpstreamModel: "gpt-5.1-codex",
ServiceTier: &serviceTier, ServiceTier: &serviceTier,
ReasoningEffort: &reasoning, ReasoningEffort: &reasoning,
Usage: OpenAIUsage{ Usage: OpenAIUsage{
@@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog) require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
require.NotNil(t, usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ServiceTier)
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
require.NotNil(t, usageRepo.lastLog.ReasoningEffort) require.NotNil(t, usageRepo.lastLog.ReasoningEffort)