From 51547fa216608cbc9661ecd585ada230cfb9377c Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Tue, 17 Mar 2026 19:25:17 +0800 Subject: [PATCH 1/7] feat(db): add upstream_model column to usage_logs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add nullable VARCHAR(100) column to record the actual model sent to upstream providers when model mapping is applied. NULL means no mapping — the requested model was used as-is. Includes migration, concurrent index for aggregation queries, Ent schema regeneration, and migration README correction (forward-only runner, not goose). --- backend/ent/migrate/schema.go | 29 +++---- backend/ent/mutation.go | 75 ++++++++++++++++- backend/ent/runtime/runtime.go | 48 ++++++----- backend/ent/schema/usage_log.go | 6 ++ backend/ent/usagelog.go | 16 +++- backend/ent/usagelog/usagelog.go | 10 +++ backend/ent/usagelog/where.go | 80 ++++++++++++++++++ backend/ent/usagelog_create.go | 83 +++++++++++++++++++ backend/ent/usagelog_update.go | 62 ++++++++++++++ backend/go.mod | 2 + backend/go.sum | 4 + .../075_add_usage_log_upstream_model.sql | 4 + ...dd_usage_log_upstream_model_index_notx.sql | 3 + backend/migrations/README.md | 38 +++------ 14 files changed, 396 insertions(+), 64 deletions(-) create mode 100644 backend/migrations/075_add_usage_log_upstream_model.sql create mode 100644 backend/migrations/076_add_usage_log_upstream_model_index_notx.sql diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index ff1c1b88..acdd0d18 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -716,6 +716,7 @@ var ( {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "request_id", Type: field.TypeString, Size: 64}, {Name: "model", Type: field.TypeString, Size: 100}, + {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, @@ -755,31 +756,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -788,32 +789,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_model", @@ -828,17 +829,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 652adcac..ff58fa9e 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -18239,6 +18239,7 @@ type UsageLogMutation struct { id *int64 request_id *string model *string + upstream_model *string input_tokens *int addinput_tokens *int output_tokens *int @@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() { m.model = nil } +// SetUpstreamModel sets the "upstream_model" field. +func (m *UsageLogMutation) SetUpstreamModel(s string) { + m.upstream_model = &s +} + +// UpstreamModel returns the value of the "upstream_model" field in the mutation. +func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) { + v := m.upstream_model + if v == nil { + return + } + return *v, true +} + +// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpstreamModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err) + } + return oldValue.UpstreamModel, nil +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (m *UsageLogMutation) ClearUpstreamModel() { + m.upstream_model = nil + m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{} +} + +// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation. +func (m *UsageLogMutation) UpstreamModelCleared() bool { + _, ok := m.clearedFields[usagelog.FieldUpstreamModel] + return ok +} + +// ResetUpstreamModel resets all changes to the "upstream_model" field. +func (m *UsageLogMutation) ResetUpstreamModel() { + m.upstream_model = nil + delete(m.clearedFields, usagelog.FieldUpstreamModel) +} + // SetGroupID sets the "group_id" field. func (m *UsageLogMutation) SetGroupID(i int64) { m.group = &i @@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 32) + fields := make([]string, 0, 33) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string { if m.model != nil { fields = append(fields, usagelog.FieldModel) } + if m.upstream_model != nil { + fields = append(fields, usagelog.FieldUpstreamModel) + } if m.group != nil { fields = append(fields, usagelog.FieldGroupID) } @@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.RequestID() case usagelog.FieldModel: return m.Model() + case usagelog.FieldUpstreamModel: + return m.UpstreamModel() case usagelog.FieldGroupID: return m.GroupID() case usagelog.FieldSubscriptionID: @@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldRequestID(ctx) case usagelog.FieldModel: return m.OldModel(ctx) + case usagelog.FieldUpstreamModel: + return m.OldUpstreamModel(ctx) case usagelog.FieldGroupID: return m.OldGroupID(ctx) case usagelog.FieldSubscriptionID: @@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetModel(v) return nil + case usagelog.FieldUpstreamModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpstreamModel(v) + return nil case usagelog.FieldGroupID: v, ok := value.(int64) if !ok { @@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error { // mutation. func (m *UsageLogMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(usagelog.FieldUpstreamModel) { + fields = append(fields, usagelog.FieldUpstreamModel) + } if m.FieldCleared(usagelog.FieldGroupID) { fields = append(fields, usagelog.FieldGroupID) } @@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *UsageLogMutation) ClearField(name string) error { switch name { + case usagelog.FieldUpstreamModel: + m.ClearUpstreamModel() + return nil case usagelog.FieldGroupID: m.ClearGroupID() return nil @@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldModel: m.ResetModel() return nil + case usagelog.FieldUpstreamModel: + m.ResetUpstreamModel() + return nil case usagelog.FieldGroupID: m.ResetGroupID() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index b8facf36..2401e553 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -821,92 +821,96 @@ func init() { return nil } }() + // usagelogDescUpstreamModel is the schema descriptor for upstream_model field. + usagelogDescUpstreamModel := usagelogFields[5].Descriptor() + // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. + usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) // usagelogDescInputTokens is the schema descriptor for input_tokens field. - usagelogDescInputTokens := usagelogFields[7].Descriptor() + usagelogDescInputTokens := usagelogFields[8].Descriptor() // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) // usagelogDescOutputTokens is the schema descriptor for output_tokens field. - usagelogDescOutputTokens := usagelogFields[8].Descriptor() + usagelogDescOutputTokens := usagelogFields[9].Descriptor() // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. - usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor() + usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor() // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. - usagelogDescCacheReadTokens := usagelogFields[10].Descriptor() + usagelogDescCacheReadTokens := usagelogFields[11].Descriptor() // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. - usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor() + usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor() // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. - usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor() + usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor() // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) // usagelogDescInputCost is the schema descriptor for input_cost field. - usagelogDescInputCost := usagelogFields[13].Descriptor() + usagelogDescInputCost := usagelogFields[14].Descriptor() // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) // usagelogDescOutputCost is the schema descriptor for output_cost field. - usagelogDescOutputCost := usagelogFields[14].Descriptor() + usagelogDescOutputCost := usagelogFields[15].Descriptor() // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. - usagelogDescCacheCreationCost := usagelogFields[15].Descriptor() + usagelogDescCacheCreationCost := usagelogFields[16].Descriptor() // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. - usagelogDescCacheReadCost := usagelogFields[16].Descriptor() + usagelogDescCacheReadCost := usagelogFields[17].Descriptor() // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) // usagelogDescTotalCost is the schema descriptor for total_cost field. - usagelogDescTotalCost := usagelogFields[17].Descriptor() + usagelogDescTotalCost := usagelogFields[18].Descriptor() // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) // usagelogDescActualCost is the schema descriptor for actual_cost field. - usagelogDescActualCost := usagelogFields[18].Descriptor() + usagelogDescActualCost := usagelogFields[19].Descriptor() // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. - usagelogDescRateMultiplier := usagelogFields[19].Descriptor() + usagelogDescRateMultiplier := usagelogFields[20].Descriptor() // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) // usagelogDescBillingType is the schema descriptor for billing_type field. - usagelogDescBillingType := usagelogFields[21].Descriptor() + usagelogDescBillingType := usagelogFields[22].Descriptor() // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) // usagelogDescStream is the schema descriptor for stream field. - usagelogDescStream := usagelogFields[22].Descriptor() + usagelogDescStream := usagelogFields[23].Descriptor() // usagelog.DefaultStream holds the default value on creation for the stream field. usagelog.DefaultStream = usagelogDescStream.Default.(bool) // usagelogDescUserAgent is the schema descriptor for user_agent field. - usagelogDescUserAgent := usagelogFields[25].Descriptor() + usagelogDescUserAgent := usagelogFields[26].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) // usagelogDescIPAddress is the schema descriptor for ip_address field. - usagelogDescIPAddress := usagelogFields[26].Descriptor() + usagelogDescIPAddress := usagelogFields[27].Descriptor() // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[27].Descriptor() + usagelogDescImageCount := usagelogFields[28].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[28].Descriptor() + usagelogDescImageSize := usagelogFields[29].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) // usagelogDescMediaType is the schema descriptor for media_type field. - usagelogDescMediaType := usagelogFields[29].Descriptor() + usagelogDescMediaType := usagelogFields[30].Descriptor() // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. - usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[31].Descriptor() + usagelogDescCreatedAt := usagelogFields[32].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index dcca1a0a..8f8a5255 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field { field.String("model"). MaxLen(100). NotEmpty(), + // UpstreamModel stores the actual upstream model name when model mapping + // is applied. NULL means no mapping — the requested model was used as-is. + field.String("upstream_model"). + MaxLen(100). + Optional(). + Nillable(), field.Int64("group_id"). Optional(). Nillable(), diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index f6968d0d..014851c9 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -32,6 +32,8 @@ type UsageLog struct { RequestID string `json:"request_id,omitempty"` // Model holds the value of the "model" field. Model string `json:"model,omitempty"` + // UpstreamModel holds the value of the "upstream_model" field. + UpstreamModel *string `json:"upstream_model,omitempty"` // GroupID holds the value of the "group_id" field. GroupID *int64 `json:"group_id,omitempty"` // SubscriptionID holds the value of the "subscription_id" field. @@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Model = value.String } + case usagelog.FieldUpstreamModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field upstream_model", values[i]) + } else if value.Valid { + _m.UpstreamModel = new(string) + *_m.UpstreamModel = value.String + } case usagelog.FieldGroupID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field group_id", values[i]) @@ -477,6 +486,11 @@ func (_m *UsageLog) String() string { builder.WriteString("model=") builder.WriteString(_m.Model) builder.WriteString(", ") + if v := _m.UpstreamModel; v != nil { + builder.WriteString("upstream_model=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.GroupID; v != nil { builder.WriteString("group_id=") builder.WriteString(fmt.Sprintf("%v", *v)) diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index ba97b843..789407e7 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -24,6 +24,8 @@ const ( FieldRequestID = "request_id" // FieldModel holds the string denoting the model field in the database. FieldModel = "model" + // FieldUpstreamModel holds the string denoting the upstream_model field in the database. + FieldUpstreamModel = "upstream_model" // FieldGroupID holds the string denoting the group_id field in the database. FieldGroupID = "group_id" // FieldSubscriptionID holds the string denoting the subscription_id field in the database. @@ -135,6 +137,7 @@ var Columns = []string{ FieldAccountID, FieldRequestID, FieldModel, + FieldUpstreamModel, FieldGroupID, FieldSubscriptionID, FieldInputTokens, @@ -179,6 +182,8 @@ var ( RequestIDValidator func(string) error // ModelValidator is a validator for the "model" field. It is called by the builders before save. ModelValidator func(string) error + // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. + UpstreamModelValidator func(string) error // DefaultInputTokens holds the default value on creation for the "input_tokens" field. DefaultInputTokens int // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. @@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModel, opts...).ToFunc() } +// ByUpstreamModel orders the results by the upstream_model field. +func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() +} + // ByGroupID orders the results by the group_id field. func ByGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldGroupID, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index af960335..5f341976 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) } +// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ. +func UpstreamModel(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) +} + // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. func GroupID(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) @@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) } +// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field. +func UpstreamModelEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) +} + +// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field. +func UpstreamModelNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v)) +} + +// UpstreamModelIn applies the In predicate on the "upstream_model" field. +func UpstreamModelIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...)) +} + +// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field. +func UpstreamModelNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...)) +} + +// UpstreamModelGT applies the GT predicate on the "upstream_model" field. +func UpstreamModelGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v)) +} + +// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field. +func UpstreamModelGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v)) +} + +// UpstreamModelLT applies the LT predicate on the "upstream_model" field. +func UpstreamModelLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v)) +} + +// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field. +func UpstreamModelLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v)) +} + +// UpstreamModelContains applies the Contains predicate on the "upstream_model" field. +func UpstreamModelContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v)) +} + +// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field. +func UpstreamModelHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v)) +} + +// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field. +func UpstreamModelHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v)) +} + +// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field. +func UpstreamModelIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel)) +} + +// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field. +func UpstreamModelNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel)) +} + +// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field. +func UpstreamModelEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v)) +} + +// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field. +func UpstreamModelContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v)) +} + // GroupIDEQ applies the EQ predicate on the "group_id" field. func GroupIDEQ(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index e0285a5e..26be5dcb 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate { return _c } +// SetUpstreamModel sets the "upstream_model" field. +func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate { + _c.mutation.SetUpstreamModel(v) + return _c +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate { + if v != nil { + _c.SetUpstreamModel(*v) + } + return _c +} + // SetGroupID sets the "group_id" field. func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { _c.mutation.SetGroupID(v) @@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _c.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if _, ok := _c.mutation.InputTokens(); !ok { return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} } @@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldModel, field.TypeString, value) _node.Model = value } + if value, ok := _c.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + _node.UpstreamModel = &value + } if value, ok := _c.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _node.InputTokens = value @@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert { return u } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert { + u.Set(usagelog.FieldUpstreamModel, v) + return u +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldUpstreamModel) + return u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert { + u.SetNull(usagelog.FieldUpstreamModel) + return u +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { u.Set(usagelog.FieldGroupID, v) @@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne { }) } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetUpstreamModel(v) + }) +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUpstreamModel() + }) +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearUpstreamModel() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk { }) } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetUpstreamModel(v) + }) +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUpstreamModel() + }) +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearUpstreamModel() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index b46e5b56..b7c4632c 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate { return _u } +// SetUpstreamModel sets the "upstream_model" field. +func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate { + _u.mutation.SetUpstreamModel(v) + return _u +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate { + if v != nil { + _u.SetUpstreamModel(*v) + } + return _u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate { + _u.mutation.ClearUpstreamModel() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { _u.mutation.SetGroupID(v) @@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + } + if _u.mutation.UpstreamModelCleared() { + _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } @@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne { return _u } +// SetUpstreamModel sets the "upstream_model" field. +func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne { + _u.mutation.SetUpstreamModel(v) + return _u +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetUpstreamModel(*v) + } + return _u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne { + _u.mutation.ClearUpstreamModel() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { _u.mutation.SetGroupID(v) @@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + } + if _u.mutation.UpstreamModelCleared() { + _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } diff --git a/backend/go.mod b/backend/go.mod index 135cbd3e..509619b1 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -107,6 +107,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -176,6 +177,7 @@ require ( golang.org/x/mod v0.32.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect + golang.org/x/tools v0.41.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 324fe652..847888e3 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -184,6 +184,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -203,6 +205,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= diff --git a/backend/migrations/075_add_usage_log_upstream_model.sql b/backend/migrations/075_add_usage_log_upstream_model.sql new file mode 100644 index 00000000..7f9f8ec6 --- /dev/null +++ b/backend/migrations/075_add_usage_log_upstream_model.sql @@ -0,0 +1,4 @@ +-- Add upstream_model field to usage_logs. +-- Stores the actual upstream model name when it differs from the requested model +-- (i.e., when model mapping is applied). NULL means no mapping was applied. +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100); diff --git a/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql b/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql new file mode 100644 index 00000000..9eee61be --- /dev/null +++ b/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql @@ -0,0 +1,3 @@ +-- Support upstream_model / mapping model distribution aggregations with time-range filters. +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_model_upstream_model +ON usage_logs (created_at, model, upstream_model); diff --git a/backend/migrations/README.md b/backend/migrations/README.md index 47f6fa35..40455ad9 100644 --- a/backend/migrations/README.md +++ b/backend/migrations/README.md @@ -34,18 +34,18 @@ Example: `017_add_gemini_tier_id.sql` ## Migration File Structure -```sql --- +goose Up --- +goose StatementBegin --- Your forward migration SQL here --- +goose StatementEnd +This project uses a custom migration runner (`internal/repository/migrations_runner.go`) that executes the full SQL file content as-is. --- +goose Down --- +goose StatementBegin --- Your rollback migration SQL here --- +goose StatementEnd +- Regular migrations (`*.sql`): executed in a transaction. +- Non-transactional migrations (`*_notx.sql`): split by statement and executed without transaction (for `CONCURRENTLY`). + +```sql +-- Forward-only migration (recommended) +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS example_column VARCHAR(100); ``` +> ⚠️ Do **not** place executable "Down" SQL in the same file. The runner does not parse goose Up/Down sections and will execute all SQL statements in the file. + ## Important Rules ### ⚠️ Immutability Principle @@ -66,9 +66,9 @@ Why? touch migrations/018_your_change.sql ``` -2. **Write Up and Down migrations** - - Up: Apply the change - - Down: Revert the change (should be symmetric with Up) +2. **Write forward-only migration SQL** + - Put only the intended schema change in the file + - If rollback is needed, create a new migration file to revert 3. **Test locally** ```bash @@ -144,8 +144,6 @@ touch migrations/018_your_new_change.sql ## Example Migration ```sql --- +goose Up --- +goose StatementBegin -- Add tier_id field to Gemini OAuth accounts for quota tracking UPDATE accounts SET credentials = jsonb_set( @@ -157,17 +155,6 @@ SET credentials = jsonb_set( WHERE platform = 'gemini' AND type = 'oauth' AND credentials->>'tier_id' IS NULL; --- +goose StatementEnd - --- +goose Down --- +goose StatementBegin --- Remove tier_id field -UPDATE accounts -SET credentials = credentials - 'tier_id' -WHERE platform = 'gemini' - AND type = 'oauth' - AND credentials->>'tier_id' = 'LEGACY'; --- +goose StatementEnd ``` ## Troubleshooting @@ -194,5 +181,4 @@ VALUES ('NNN_migration.sql', 'calculated_checksum', NOW()); ## References - Migration runner: `internal/repository/migrations_runner.go` -- Goose syntax: https://github.com/pressly/goose - PostgreSQL docs: https://www.postgresql.org/docs/ From 2e4ac88ad9773436b4156a410510ffe8f173ba7a Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Tue, 17 Mar 2026 19:25:35 +0800 Subject: [PATCH 2/7] feat(service): record upstream model across all gateway paths Propagate UpstreamModel through ForwardResult and OpenAIForwardResult in Anthropic direct, API-key passthrough, Bedrock, and OpenAI gateway flows. Extract optionalNonEqualStringPtr and optionalTrimmedStringPtr into usage_log_helpers.go. Store upstream_model only when it differs from the requested model. Also introduces anthropicPassthroughForwardInput struct to reduce parameter count. --- backend/internal/service/gateway_service.go | 56 +++++++++++++++---- .../openai_gateway_chat_completions.go | 28 +++++----- .../service/openai_gateway_messages.go | 28 +++++----- .../service/openai_gateway_service.go | 16 +++--- backend/internal/service/usage_log.go | 3 + backend/internal/service/usage_log_helpers.go | 21 +++++++ 6 files changed, 107 insertions(+), 45 deletions(-) create mode 100644 backend/internal/service/usage_log_helpers.go diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 0b50162a..4544ec82 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -491,6 +491,7 @@ type ForwardResult struct { RequestID string Usage ClaudeUsage Model string + UpstreamModel string // Actual upstream model after mapping (empty = no mapping) Stream bool Duration time.Duration FirstTokenMs *int // 首字时间(流式请求) @@ -3989,7 +3990,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A passthroughModel = mappedModel } } - return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: passthroughBody, + RequestModel: passthroughModel, + OriginalModel: parsed.Model, + RequestStream: parsed.Stream, + StartTime: startTime, + }) } if account != nil && account.IsBedrock() { @@ -4513,6 +4520,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, // 使用原始模型用于计费和日志 + UpstreamModel: mappedModel, Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -4520,14 +4528,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } +type anthropicPassthroughForwardInput struct { + Body []byte + RequestModel string + OriginalModel string + RequestStream bool + StartTime time.Time +} + func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ctx context.Context, c *gin.Context, account *Account, body []byte, reqModel string, + originalModel string, reqStream bool, startTime time.Time, +) (*ForwardResult, error) { + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: body, + RequestModel: reqModel, + OriginalModel: originalModel, + RequestStream: reqStream, + StartTime: startTime, + }) +} + +func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( + ctx context.Context, + c *gin.Context, + account *Account, + input anthropicPassthroughForwardInput, ) (*ForwardResult, error) { token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { @@ -4543,19 +4575,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( } logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", - account.ID, account.Name, reqModel, reqStream) + account.ID, account.Name, input.RequestModel, input.RequestStream) if c != nil { c.Set("anthropic_passthrough", true) } // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 - setOpsUpstreamRequestBody(c, body) + setOpsUpstreamRequestBody(c, input.Body) var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) - upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token) releaseUpstreamCtx() if err != nil { return nil, err @@ -4713,8 +4745,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool - if reqStream { - streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) + if input.RequestStream { + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel) if err != nil { return nil, err } @@ -4734,9 +4766,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( return &ForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, - Model: reqModel, - Stream: reqStream, - Duration: time.Since(startTime), + Model: input.OriginalModel, + UpstreamModel: input.RequestModel, + Stream: input.RequestStream, + Duration: time.Since(input.StartTime), FirstTokenMs: firstTokenMs, ClientDisconnect: clientDisconnect, }, nil @@ -5241,6 +5274,7 @@ func (s *GatewayService) forwardBedrock( RequestID: resp.Header.Get("x-amzn-requestid"), Usage: *usage, Model: reqModel, + UpstreamModel: mappedModel, Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -7529,6 +7563,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu AccountID: account.ID, RequestID: requestID, Model: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), @@ -7710,6 +7745,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * AccountID: account.ID, RequestID: requestID, Model: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 9529f6be..7202f7cb 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -277,12 +277,13 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( c.JSON(http.StatusOK, chatResp) return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: false, - Duration: time.Since(startTime), + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), }, nil } @@ -324,13 +325,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + UpstreamModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } } diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 58714571..6a29823a 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -299,12 +299,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( c.JSON(http.StatusOK, anthropicResp) return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: false, - Duration: time.Since(startTime), + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), }, nil } @@ -347,13 +348,14 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( // resultWithUsage builds the final result snapshot. resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + UpstreamModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index c8876edb..cf902c20 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -216,6 +216,9 @@ type OpenAIForwardResult struct { // This is set by the Anthropic Messages conversion path where // the mapped upstream model differs from the client-facing model. BillingModel string + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Empty when no mapping was applied (requested model was used as-is). + UpstreamModel string // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". // Nil means the request did not specify a recognized tier. ServiceTier *string @@ -2128,6 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco firstTokenMs, wsAttempts, ) + wsResult.UpstreamModel = mappedModel return wsResult, nil } s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) @@ -2263,6 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, + UpstreamModel: mappedModel, ServiceTier: serviceTier, ReasoningEffort: reasoningEffort, Stream: reqStream, @@ -4134,7 +4139,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: requestID, - Model: billingModel, + Model: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -4700,11 +4706,3 @@ func normalizeOpenAIReasoningEffort(raw string) string { return "" } } - -func optionalTrimmedStringPtr(raw string) *string { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return nil - } - return &trimmed -} diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 7f1bef7f..5a498676 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -98,6 +98,9 @@ type UsageLog struct { AccountID int64 RequestID string Model string + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Nil means no mapping was applied (requested model was used as-is). + UpstreamModel *string // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". ServiceTier *string // ReasoningEffort is the request's reasoning effort level. diff --git a/backend/internal/service/usage_log_helpers.go b/backend/internal/service/usage_log_helpers.go new file mode 100644 index 00000000..2ab51849 --- /dev/null +++ b/backend/internal/service/usage_log_helpers.go @@ -0,0 +1,21 @@ +package service + +import "strings" + +func optionalTrimmedStringPtr(raw string) *string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil + } + return &trimmed +} + +// optionalNonEqualStringPtr returns a pointer to value if it is non-empty and +// differs from compare; otherwise nil. Used to store upstream_model only when +// it differs from the requested model. +func optionalNonEqualStringPtr(value, compare string) *string { + if value == "" || value == compare { + return nil + } + return &value +} From 7134266acfae3e3bfce4d983b6258afc6624526c Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Tue, 17 Mar 2026 19:25:52 +0800 Subject: [PATCH 3/7] feat(dashboard): add model source dimension to stats queries Support querying model statistics by 'requested', 'upstream', or 'mapping' dimension. Add resolveModelDimensionExpression for safe SQL expression generation, IsValidModelSource whitelist validator, and NormalizeModelSource fallback. Repository persists and scans upstream_model in all insert/select paths. --- .../pkg/usagestats/usage_log_types.go | 23 ++++++ backend/internal/repository/usage_log_repo.go | 77 ++++++++++++++----- backend/internal/service/dashboard_service.go | 21 +++++ 3 files changed, 102 insertions(+), 19 deletions(-) diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index f42a746f..de3ad378 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -3,6 +3,28 @@ package usagestats import "time" +const ( + ModelSourceRequested = "requested" + ModelSourceUpstream = "upstream" + ModelSourceMapping = "mapping" +) + +func IsValidModelSource(source string) bool { + switch source { + case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping: + return true + default: + return false + } +} + +func NormalizeModelSource(source string) string { + if IsValidModelSource(source) { + return source + } + return ModelSourceRequested +} + // DashboardStats 仪表盘统计 type DashboardStats struct { // 用户统计 @@ -143,6 +165,7 @@ type UserBreakdownItem struct { type UserBreakdownDimension struct { GroupID int64 // filter by group_id (>0 to enable) Model string // filter by model name (non-empty to enable) + ModelType string // "requested", "upstream", or "mapping" Endpoint string // filter by endpoint value (non-empty to enable) EndpointType string // "inbound", "upstream", or "path" } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index dcdaeaee..61a54267 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" var usageLogInsertArgTypes = [...]string{ "bigint", @@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{ "bigint", "text", "text", + "text", "bigint", "bigint", "integer", @@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*38) + args := make([]any, 0, len(keys)*39) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*38) + args := make([]any, 0, len(preparedList)*39) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + upstreamModel := nullString(log.UpstreamModel) var requestIDArg any if requestID != "" { @@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.AccountID, requestIDArg, log.Model, + upstreamModel, groupID, subscriptionID, log.InputTokens, @@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st // GetModelStatsWithFilters returns model statistics with optional filters func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested) +} + +// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension. +// source: requested | upstream | mapping. +func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source) +} + +func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } + modelExpr := resolveModelDimensionExpression(source) query := fmt.Sprintf(` SELECT - model, + %s as model, COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, @@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start %s FROM usage_logs WHERE created_at >= $1 AND created_at < $2 - `, actualCostExpr) + `, modelExpr, actualCostExpr) args := []any{startTime, endTime} if userID > 0 { @@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) } - query += " GROUP BY model ORDER BY total_tokens DESC" + query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr) rows, err := r.sql.QueryContext(ctx, query, args...) if err != nil { @@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim args = append(args, dim.GroupID) } if dim.Model != "" { - query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1) + query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1) args = append(args, dim.Model) } if dim.Endpoint != "" { @@ -3067,6 +3089,18 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim return results, nil } +// resolveModelDimensionExpression maps model source type to a safe SQL expression. +func resolveModelDimensionExpression(modelType string) string { + switch usagestats.NormalizeModelSource(modelType) { + case usagestats.ModelSourceUpstream: + return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)" + case usagestats.ModelSourceMapping: + return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))" + default: + return "model" + } +} + // resolveEndpointColumn maps endpoint type to the corresponding DB column name. func resolveEndpointColumn(endpointType string) string { switch endpointType { @@ -3819,6 +3853,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e accountID int64 requestID sql.NullString model string + upstreamModel sql.NullString groupID sql.NullInt64 subscriptionID sql.NullInt64 inputTokens int @@ -3861,6 +3896,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &accountID, &requestID, &model, + &upstreamModel, &groupID, &subscriptionID, &inputTokens, @@ -3973,6 +4009,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if upstreamEndpoint.Valid { log.UpstreamEndpoint = &upstreamEndpoint.String } + if upstreamModel.Valid { + log.UpstreamModel = &upstreamModel.String + } return log, nil } diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index ad29990f..1c960fdf 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi return stats, nil } +func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) { + normalizedSource := usagestats.NormalizeModelSource(modelSource) + if normalizedSource == usagestats.ModelSourceRequested { + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + } + + type modelStatsBySourceRepo interface { + GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error) + } + + if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok { + stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource) + if err != nil { + return nil, fmt.Errorf("get model stats with filters by source: %w", err) + } + return stats, nil + } + + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) +} + func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { From 56fcb20f94f3d651d87753b84db54c42dada6ffd Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Tue, 17 Mar 2026 19:26:11 +0800 Subject: [PATCH 4/7] feat(api): expose model_source filter in dashboard endpoints Add model_source query parameter to GetModelStats and GetUserBreakdown handlers with explicit IsValidModelSource validation. Include model_source in cache key to prevent cross-source cache hits. Expose upstream_model in usage log DTO with omitempty semantics. --- .../internal/handler/admin/dashboard_handler.go | 16 +++++++++++++++- .../handler/admin/dashboard_query_cache.go | 5 ++++- .../admin/dashboard_snapshot_v2_handler.go | 1 + backend/internal/handler/dto/mappers.go | 1 + backend/internal/handler/dto/types.go | 3 +++ 5 files changed, 24 insertions(+), 2 deletions(-) diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index a34bbd39..2a214471 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -273,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 + modelSource := usagestats.ModelSourceRequested var requestType *int16 var stream *bool var billingType *int8 @@ -297,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { groupID = id } } + if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" { + if !usagestats.IsValidModelSource(rawModelSource) { + response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping") + return + } + modelSource = rawModelSource + } if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { parsed, err := service.ParseUsageRequestType(requestTypeStr) if err != nil { @@ -323,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return @@ -619,6 +627,12 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) { } } dim.Model = c.Query("model") + rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested)) + if !usagestats.IsValidModelSource(rawModelSource) { + response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping") + return + } + dim.ModelType = rawModelSource dim.Endpoint = c.Query("endpoint") dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") diff --git a/backend/internal/handler/admin/dashboard_query_cache.go b/backend/internal/handler/admin/dashboard_query_cache.go index 47af5117..815c5161 100644 --- a/backend/internal/handler/admin/dashboard_query_cache.go +++ b/backend/internal/handler/admin/dashboard_query_cache.go @@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct { APIKeyID int64 `json:"api_key_id"` AccountID int64 `json:"account_id"` GroupID int64 `json:"group_id"` + ModelSource string `json:"model_source,omitempty"` RequestType *int16 `json:"request_type"` Stream *bool `json:"stream"` BillingType *int8 `json:"billing_type"` @@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached( ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, + modelSource string, requestType *int16, stream *bool, billingType *int8, @@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached( APIKeyID: apiKeyID, AccountID: accountID, GroupID: groupID, + ModelSource: usagestats.NormalizeModelSource(modelSource), RequestType: requestType, Stream: stream, BillingType: billingType, }) entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) { - return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource) }) if err != nil { return nil, hit, err diff --git a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go index 16e10339..517ae7bd 100644 --- a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go +++ b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go @@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response( filters.APIKeyID, filters.AccountID, filters.GroupID, + usagestats.ModelSourceRequested, filters.RequestType, filters.Stream, filters.BillingType, diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 8e5f23e7..cc25f7c3 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -521,6 +521,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { AccountID: l.AccountID, RequestID: l.RequestID, Model: l.Model, + UpstreamModel: l.UpstreamModel, ServiceTier: l.ServiceTier, ReasoningEffort: l.ReasoningEffort, InboundEndpoint: l.InboundEndpoint, diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index c52e357e..fa360804 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -332,6 +332,9 @@ type UsageLog struct { AccountID int64 `json:"account_id"` RequestID string `json:"request_id"` Model string `json:"model"` + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Omitted when no mapping was applied (requested model was used as-is). + UpstreamModel *string `json:"upstream_model,omitempty"` // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". ServiceTier *string `json:"service_tier,omitempty"` // ReasoningEffort is the request's reasoning effort level. From eeff451bc58717b994901c48207be19daba39fdc Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Tue, 17 Mar 2026 19:26:30 +0800 Subject: [PATCH 5/7] test(backend): add tests for upstream model tracking and model source filtering Cover IsValidModelSource/NormalizeModelSource, resolveModelDimensionExpression SQL expressions, invalid model_source 400 responses on both GetModelStats and GetUserBreakdown, upstream_model in scan/insert SQL mock expectations, and updated passthrough/billing test signatures. --- .../dashboard_handler_request_type_test.go | 22 +++++++++ .../dashboard_handler_user_breakdown_test.go | 26 ++++++++++ .../pkg/usagestats/usage_log_types_test.go | 47 +++++++++++++++++++ .../usage_log_repo_breakdown_test.go | 25 +++++++++- .../usage_log_repo_request_type_test.go | 5 ++ ...teway_anthropic_apikey_passthrough_test.go | 8 ++-- .../openai_gateway_record_usage_test.go | 7 ++- 7 files changed, 132 insertions(+), 8 deletions(-) create mode 100644 backend/internal/pkg/usagestats/usage_log_types_test.go diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go index 9aec61d4..6056f725 100644 --- a/backend/internal/handler/admin/dashboard_handler_request_type_test.go +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } +func TestDashboardModelStatsInvalidModelSource(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsValidModelSource(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} + func TestDashboardUsersRankingLimitAndCache(t *testing.T) { dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) repo := &dashboardUsageRepoCapture{ diff --git a/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go b/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go index 2c1dbd59..b3a05111 100644 --- a/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go +++ b/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go @@ -73,9 +73,35 @@ func TestGetUserBreakdown_ModelFilter(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model) + require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType) require.Equal(t, int64(0), repo.capturedDim.GroupID) } +func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType) +} + +func TestGetUserBreakdown_InvalidModelSource(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code) +} + func TestGetUserBreakdown_EndpointFilter(t *testing.T) { repo := &userBreakdownRepoCapture{} router := newUserBreakdownRouter(repo) diff --git a/backend/internal/pkg/usagestats/usage_log_types_test.go b/backend/internal/pkg/usagestats/usage_log_types_test.go new file mode 100644 index 00000000..95cf6069 --- /dev/null +++ b/backend/internal/pkg/usagestats/usage_log_types_test.go @@ -0,0 +1,47 @@ +package usagestats + +import "testing" + +func TestIsValidModelSource(t *testing.T) { + tests := []struct { + name string + source string + want bool + }{ + {name: "requested", source: ModelSourceRequested, want: true}, + {name: "upstream", source: ModelSourceUpstream, want: true}, + {name: "mapping", source: ModelSourceMapping, want: true}, + {name: "invalid", source: "foobar", want: false}, + {name: "empty", source: "", want: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := IsValidModelSource(tc.source); got != tc.want { + t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want) + } + }) + } +} + +func TestNormalizeModelSource(t *testing.T) { + tests := []struct { + name string + source string + want string + }{ + {name: "requested", source: ModelSourceRequested, want: ModelSourceRequested}, + {name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream}, + {name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping}, + {name: "invalid falls back", source: "foobar", want: ModelSourceRequested}, + {name: "empty falls back", source: "", want: ModelSourceRequested}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := NormalizeModelSource(tc.source); got != tc.want { + t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want) + } + }) + } +} diff --git a/backend/internal/repository/usage_log_repo_breakdown_test.go b/backend/internal/repository/usage_log_repo_breakdown_test.go index ca63e0bc..5d908bfd 100644 --- a/backend/internal/repository/usage_log_repo_breakdown_test.go +++ b/backend/internal/repository/usage_log_repo_breakdown_test.go @@ -5,6 +5,7 @@ package repository import ( "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/stretchr/testify/require" ) @@ -16,8 +17,8 @@ func TestResolveEndpointColumn(t *testing.T) { {"inbound", "ul.inbound_endpoint"}, {"upstream", "ul.upstream_endpoint"}, {"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"}, - {"", "ul.inbound_endpoint"}, // default - {"unknown", "ul.inbound_endpoint"}, // fallback + {"", "ul.inbound_endpoint"}, // default + {"unknown", "ul.inbound_endpoint"}, // fallback } for _, tc := range tests { @@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) { }) } } + +func TestResolveModelDimensionExpression(t *testing.T) { + tests := []struct { + modelType string + want string + }{ + {usagestats.ModelSourceRequested, "model"}, + {usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"}, + {usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"}, + {"", "model"}, + {"invalid", "model"}, + } + + for _, tc := range tests { + t.Run(tc.modelType, func(t *testing.T) { + got := resolveModelDimensionExpression(tc.modelType) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 27ae4571..76827c31 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.AccountID, log.RequestID, log.Model, + sqlmock.AnyArg(), // upstream_model sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // subscription_id log.InputTokens, @@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { log.Model, sqlmock.AnyArg(), sqlmock.AnyArg(), + sqlmock.AnyArg(), log.InputTokens, log.OutputTokens, log.CacheCreationTokens, @@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(30), // account_id sql.NullString{Valid: true, String: "req-1"}, "gpt-5", // model + sql.NullString{}, // upstream_model sql.NullInt64{}, // group_id sql.NullInt64{}, // subscription_id 1, // input_tokens @@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(31), sql.NullString{Valid: true, String: "req-2"}, "gpt-5", + sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, @@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(32), sql.NullString{Valid: true, String: "req-3"}, "gpt-5.4", + sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 789cbab8..c534a9b7 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc rateLimitService: &RateLimitService{}, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 12, result.Usage.InputTokens) @@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp } svc := &GatewayService{} - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "requires apikey token") @@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest } account := newAnthropicAPIKeyAccountForTest() - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "upstream request failed") @@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo httpUpstream: upstream, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "empty response") diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index ada7d805..a35f9127 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) { require.Nil(t, extractOpenAIServiceTierFromBody(nil)) } -func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { +func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{} @@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te RequestID: "resp_billing_model_override", BillingModel: "gpt-5.1-codex", Model: "gpt-5.1", + UpstreamModel: "gpt-5.1-codex", ServiceTier: &serviceTier, ReasoningEffort: &reasoning, Usage: OpenAIUsage{ @@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te require.NoError(t, err) require.NotNil(t, usageRepo.lastLog) - require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel) require.NotNil(t, usageRepo.lastLog.ServiceTier) require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ReasoningEffort) From 62b40636e09f58755c297d6d30feb7aac97c6502 Mon Sep 17 00:00:00 2001 From: Ethan0x0000 <3352979663@qq.com> Date: Tue, 17 Mar 2026 19:26:48 +0800 Subject: [PATCH 6/7] feat(frontend): display upstream model in usage table and distribution charts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Show upstream model mapping (requested -> upstream) in UsageTable with arrow notation. Add requested/upstream/mapping source toggle to ModelDistributionChart with lazy loading — only fetches data when user switches tab, with per-source cache invalidation on filter changes. Include upstream_model column in Excel export and i18n for en/zh. --- frontend/src/api/admin/dashboard.ts | 2 + .../src/components/admin/usage/UsageTable.vue | 12 +- .../charts/EndpointDistributionChart.vue | 4 +- .../charts/ModelDistributionChart.vue | 59 +++++++++- frontend/src/i18n/locales/en.ts | 3 + frontend/src/i18n/locales/zh.ts | 3 + frontend/src/types/index.ts | 1 + frontend/src/views/admin/UsageView.vue | 111 ++++++++++++++++-- 8 files changed, 177 insertions(+), 18 deletions(-) diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 0bf0a2c5..15d1540f 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -81,6 +81,7 @@ export interface ModelStatsParams { user_id?: number api_key_id?: number model?: string + model_source?: 'requested' | 'upstream' | 'mapping' account_id?: number group_id?: number request_type?: UsageRequestType @@ -162,6 +163,7 @@ export interface UserBreakdownParams { end_date?: string group_id?: number model?: string + model_source?: 'requested' | 'upstream' | 'mapping' endpoint?: string endpoint_type?: 'inbound' | 'upstream' | 'path' limit?: number diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index aa6c2bbd..4a42ab05 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -25,8 +25,16 @@ {{ row.account?.name || '-' }} -