diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index acdd0d18..c4f3af5e 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -716,6 +716,7 @@ var ( {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "request_id", Type: field.TypeString, Size: 64}, {Name: "model", Type: field.TypeString, Size: 100}, + {Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0}, @@ -756,31 +757,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -789,38 +790,43 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_model", Unique: false, Columns: []*schema.Column{UsageLogsColumns[2]}, }, + { + Name: "usagelog_requested_model", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[3]}, + }, { Name: "usagelog_request_id", Unique: false, @@ -829,17 +835,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index ff58fa9e..10f7afe4 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -18239,6 +18239,7 @@ type UsageLogMutation struct { id *int64 request_id *string model *string + requested_model *string upstream_model *string input_tokens *int addinput_tokens *int @@ -18577,6 +18578,55 @@ func (m *UsageLogMutation) ResetModel() { m.model = nil } +// SetRequestedModel sets the "requested_model" field. +func (m *UsageLogMutation) SetRequestedModel(s string) { + m.requested_model = &s +} + +// RequestedModel returns the value of the "requested_model" field in the mutation. +func (m *UsageLogMutation) RequestedModel() (r string, exists bool) { + v := m.requested_model + if v == nil { + return + } + return *v, true +} + +// OldRequestedModel returns the old "requested_model" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldRequestedModel(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestedModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestedModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestedModel: %w", err) + } + return oldValue.RequestedModel, nil +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (m *UsageLogMutation) ClearRequestedModel() { + m.requested_model = nil + m.clearedFields[usagelog.FieldRequestedModel] = struct{}{} +} + +// RequestedModelCleared returns if the "requested_model" field was cleared in this mutation. +func (m *UsageLogMutation) RequestedModelCleared() bool { + _, ok := m.clearedFields[usagelog.FieldRequestedModel] + return ok +} + +// ResetRequestedModel resets all changes to the "requested_model" field. +func (m *UsageLogMutation) ResetRequestedModel() { + m.requested_model = nil + delete(m.clearedFields, usagelog.FieldRequestedModel) +} + // SetUpstreamModel sets the "upstream_model" field. func (m *UsageLogMutation) SetUpstreamModel(s string) { m.upstream_model = &s @@ -20247,7 +20297,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 33) + fields := make([]string, 0, 34) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -20263,6 +20313,9 @@ func (m *UsageLogMutation) Fields() []string { if m.model != nil { fields = append(fields, usagelog.FieldModel) } + if m.requested_model != nil { + fields = append(fields, usagelog.FieldRequestedModel) + } if m.upstream_model != nil { fields = append(fields, usagelog.FieldUpstreamModel) } @@ -20365,6 +20418,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.RequestID() case usagelog.FieldModel: return m.Model() + case usagelog.FieldRequestedModel: + return m.RequestedModel() case usagelog.FieldUpstreamModel: return m.UpstreamModel() case usagelog.FieldGroupID: @@ -20440,6 +20495,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldRequestID(ctx) case usagelog.FieldModel: return m.OldModel(ctx) + case usagelog.FieldRequestedModel: + return m.OldRequestedModel(ctx) case usagelog.FieldUpstreamModel: return m.OldUpstreamModel(ctx) case usagelog.FieldGroupID: @@ -20540,6 +20597,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetModel(v) return nil + case usagelog.FieldRequestedModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestedModel(v) + return nil case usagelog.FieldUpstreamModel: v, ok := value.(string) if !ok { @@ -20985,6 +21049,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error { // mutation. func (m *UsageLogMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(usagelog.FieldRequestedModel) { + fields = append(fields, usagelog.FieldRequestedModel) + } if m.FieldCleared(usagelog.FieldUpstreamModel) { fields = append(fields, usagelog.FieldUpstreamModel) } @@ -21029,6 +21096,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *UsageLogMutation) ClearField(name string) error { switch name { + case usagelog.FieldRequestedModel: + m.ClearRequestedModel() + return nil case usagelog.FieldUpstreamModel: m.ClearUpstreamModel() return nil @@ -21082,6 +21152,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldModel: m.ResetModel() return nil + case usagelog.FieldRequestedModel: + m.ResetRequestedModel() + return nil case usagelog.FieldUpstreamModel: m.ResetUpstreamModel() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 2401e553..19c58d76 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -821,96 +821,100 @@ func init() { return nil } }() + // usagelogDescRequestedModel is the schema descriptor for requested_model field. + usagelogDescRequestedModel := usagelogFields[5].Descriptor() + // usagelog.RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save. + usagelog.RequestedModelValidator = usagelogDescRequestedModel.Validators[0].(func(string) error) // usagelogDescUpstreamModel is the schema descriptor for upstream_model field. - usagelogDescUpstreamModel := usagelogFields[5].Descriptor() + usagelogDescUpstreamModel := usagelogFields[6].Descriptor() // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) // usagelogDescInputTokens is the schema descriptor for input_tokens field. - usagelogDescInputTokens := usagelogFields[8].Descriptor() + usagelogDescInputTokens := usagelogFields[9].Descriptor() // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) // usagelogDescOutputTokens is the schema descriptor for output_tokens field. - usagelogDescOutputTokens := usagelogFields[9].Descriptor() + usagelogDescOutputTokens := usagelogFields[10].Descriptor() // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. - usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor() + usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor() // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. - usagelogDescCacheReadTokens := usagelogFields[11].Descriptor() + usagelogDescCacheReadTokens := usagelogFields[12].Descriptor() // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. - usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor() + usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor() // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. - usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor() + usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor() // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) // usagelogDescInputCost is the schema descriptor for input_cost field. - usagelogDescInputCost := usagelogFields[14].Descriptor() + usagelogDescInputCost := usagelogFields[15].Descriptor() // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) // usagelogDescOutputCost is the schema descriptor for output_cost field. - usagelogDescOutputCost := usagelogFields[15].Descriptor() + usagelogDescOutputCost := usagelogFields[16].Descriptor() // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. - usagelogDescCacheCreationCost := usagelogFields[16].Descriptor() + usagelogDescCacheCreationCost := usagelogFields[17].Descriptor() // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. - usagelogDescCacheReadCost := usagelogFields[17].Descriptor() + usagelogDescCacheReadCost := usagelogFields[18].Descriptor() // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) // usagelogDescTotalCost is the schema descriptor for total_cost field. - usagelogDescTotalCost := usagelogFields[18].Descriptor() + usagelogDescTotalCost := usagelogFields[19].Descriptor() // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) // usagelogDescActualCost is the schema descriptor for actual_cost field. - usagelogDescActualCost := usagelogFields[19].Descriptor() + usagelogDescActualCost := usagelogFields[20].Descriptor() // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. - usagelogDescRateMultiplier := usagelogFields[20].Descriptor() + usagelogDescRateMultiplier := usagelogFields[21].Descriptor() // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) // usagelogDescBillingType is the schema descriptor for billing_type field. - usagelogDescBillingType := usagelogFields[22].Descriptor() + usagelogDescBillingType := usagelogFields[23].Descriptor() // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) // usagelogDescStream is the schema descriptor for stream field. - usagelogDescStream := usagelogFields[23].Descriptor() + usagelogDescStream := usagelogFields[24].Descriptor() // usagelog.DefaultStream holds the default value on creation for the stream field. usagelog.DefaultStream = usagelogDescStream.Default.(bool) // usagelogDescUserAgent is the schema descriptor for user_agent field. - usagelogDescUserAgent := usagelogFields[26].Descriptor() + usagelogDescUserAgent := usagelogFields[27].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) // usagelogDescIPAddress is the schema descriptor for ip_address field. - usagelogDescIPAddress := usagelogFields[27].Descriptor() + usagelogDescIPAddress := usagelogFields[28].Descriptor() // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[28].Descriptor() + usagelogDescImageCount := usagelogFields[29].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[29].Descriptor() + usagelogDescImageSize := usagelogFields[30].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) // usagelogDescMediaType is the schema descriptor for media_type field. - usagelogDescMediaType := usagelogFields[30].Descriptor() + usagelogDescMediaType := usagelogFields[31].Descriptor() // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. - usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[32].Descriptor() + usagelogDescCreatedAt := usagelogFields[33].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index 8f8a5255..32c39e25 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field { field.String("model"). MaxLen(100). NotEmpty(), + // RequestedModel stores the client-requested model name for stable display and analytics. + // NULL means historical rows written before requested_model dual-write was introduced. + field.String("requested_model"). + MaxLen(100). + Optional(). + Nillable(), // UpstreamModel stores the actual upstream model name when model mapping // is applied. NULL means no mapping — the requested model was used as-is. field.String("upstream_model"). @@ -181,6 +187,7 @@ func (UsageLog) Indexes() []ent.Index { index.Fields("subscription_id"), index.Fields("created_at"), index.Fields("model"), + index.Fields("requested_model"), index.Fields("request_id"), // 复合索引用于时间范围查询 index.Fields("user_id", "created_at"), diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 014851c9..fb4ee1c5 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -32,6 +32,8 @@ type UsageLog struct { RequestID string `json:"request_id,omitempty"` // Model holds the value of the "model" field. Model string `json:"model,omitempty"` + // RequestedModel holds the value of the "requested_model" field. + RequestedModel *string `json:"requested_model,omitempty"` // UpstreamModel holds the value of the "upstream_model" field. UpstreamModel *string `json:"upstream_model,omitempty"` // GroupID holds the value of the "group_id" field. @@ -177,7 +179,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -232,6 +234,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Model = value.String } + case usagelog.FieldRequestedModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field requested_model", values[i]) + } else if value.Valid { + _m.RequestedModel = new(string) + *_m.RequestedModel = value.String + } case usagelog.FieldUpstreamModel: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field upstream_model", values[i]) @@ -486,6 +495,11 @@ func (_m *UsageLog) String() string { builder.WriteString("model=") builder.WriteString(_m.Model) builder.WriteString(", ") + if v := _m.RequestedModel; v != nil { + builder.WriteString("requested_model=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.UpstreamModel; v != nil { builder.WriteString("upstream_model=") builder.WriteString(*v) diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index 789407e7..b534f193 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -24,6 +24,8 @@ const ( FieldRequestID = "request_id" // FieldModel holds the string denoting the model field in the database. FieldModel = "model" + // FieldRequestedModel holds the string denoting the requested_model field in the database. + FieldRequestedModel = "requested_model" // FieldUpstreamModel holds the string denoting the upstream_model field in the database. FieldUpstreamModel = "upstream_model" // FieldGroupID holds the string denoting the group_id field in the database. @@ -137,6 +139,7 @@ var Columns = []string{ FieldAccountID, FieldRequestID, FieldModel, + FieldRequestedModel, FieldUpstreamModel, FieldGroupID, FieldSubscriptionID, @@ -182,6 +185,8 @@ var ( RequestIDValidator func(string) error // ModelValidator is a validator for the "model" field. It is called by the builders before save. ModelValidator func(string) error + // RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save. + RequestedModelValidator func(string) error // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. UpstreamModelValidator func(string) error // DefaultInputTokens holds the default value on creation for the "input_tokens" field. @@ -263,6 +268,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModel, opts...).ToFunc() } +// ByRequestedModel orders the results by the requested_model field. +func ByRequestedModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequestedModel, opts...).ToFunc() +} + // ByUpstreamModel orders the results by the upstream_model field. func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index 5f341976..f95bceb7 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) } +// RequestedModel applies equality check predicate on the "requested_model" field. It's identical to RequestedModelEQ. +func RequestedModel(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v)) +} + // UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ. func UpstreamModel(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) @@ -410,6 +415,81 @@ func ModelContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) } +// RequestedModelEQ applies the EQ predicate on the "requested_model" field. +func RequestedModelEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v)) +} + +// RequestedModelNEQ applies the NEQ predicate on the "requested_model" field. +func RequestedModelNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldRequestedModel, v)) +} + +// RequestedModelIn applies the In predicate on the "requested_model" field. +func RequestedModelIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldRequestedModel, vs...)) +} + +// RequestedModelNotIn applies the NotIn predicate on the "requested_model" field. +func RequestedModelNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldRequestedModel, vs...)) +} + +// RequestedModelGT applies the GT predicate on the "requested_model" field. +func RequestedModelGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldRequestedModel, v)) +} + +// RequestedModelGTE applies the GTE predicate on the "requested_model" field. +func RequestedModelGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldRequestedModel, v)) +} + +// RequestedModelLT applies the LT predicate on the "requested_model" field. +func RequestedModelLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldRequestedModel, v)) +} + +// RequestedModelLTE applies the LTE predicate on the "requested_model" field. +func RequestedModelLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldRequestedModel, v)) +} + +// RequestedModelContains applies the Contains predicate on the "requested_model" field. +func RequestedModelContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldRequestedModel, v)) +} + +// RequestedModelHasPrefix applies the HasPrefix predicate on the "requested_model" field. +func RequestedModelHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldRequestedModel, v)) +} + +// RequestedModelHasSuffix applies the HasSuffix predicate on the "requested_model" field. +func RequestedModelHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldRequestedModel, v)) +} + +// RequestedModelIsNil applies the IsNil predicate on the "requested_model" field. +func RequestedModelIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldRequestedModel)) +} + +// RequestedModelNotNil applies the NotNil predicate on the "requested_model" field. +func RequestedModelNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldRequestedModel)) +} + +// RequestedModelEqualFold applies the EqualFold predicate on the "requested_model" field. +func RequestedModelEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldRequestedModel, v)) +} + +// RequestedModelContainsFold applies the ContainsFold predicate on the "requested_model" field. +func RequestedModelContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldRequestedModel, v)) +} + // UpstreamModelEQ applies the EQ predicate on the "upstream_model" field. func UpstreamModelEQ(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index 26be5dcb..6ae0bf59 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate { return _c } +// SetRequestedModel sets the "requested_model" field. +func (_c *UsageLogCreate) SetRequestedModel(v string) *UsageLogCreate { + _c.mutation.SetRequestedModel(v) + return _c +} + +// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableRequestedModel(v *string) *UsageLogCreate { + if v != nil { + _c.SetRequestedModel(*v) + } + return _c +} + // SetUpstreamModel sets the "upstream_model" field. func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate { _c.mutation.SetUpstreamModel(v) @@ -610,6 +624,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _c.mutation.RequestedModel(); ok { + if err := usagelog.RequestedModelValidator(v); err != nil { + return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)} + } + } if v, ok := _c.mutation.UpstreamModel(); ok { if err := usagelog.UpstreamModelValidator(v); err != nil { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} @@ -733,6 +752,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldModel, field.TypeString, value) _node.Model = value } + if value, ok := _c.mutation.RequestedModel(); ok { + _spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value) + _node.RequestedModel = &value + } if value, ok := _c.mutation.UpstreamModel(); ok { _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _node.UpstreamModel = &value @@ -1034,6 +1057,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert { return u } +// SetRequestedModel sets the "requested_model" field. +func (u *UsageLogUpsert) SetRequestedModel(v string) *UsageLogUpsert { + u.Set(usagelog.FieldRequestedModel, v) + return u +} + +// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateRequestedModel() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldRequestedModel) + return u +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (u *UsageLogUpsert) ClearRequestedModel() *UsageLogUpsert { + u.SetNull(usagelog.FieldRequestedModel) + return u +} + // SetUpstreamModel sets the "upstream_model" field. func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert { u.Set(usagelog.FieldUpstreamModel, v) @@ -1641,6 +1682,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne { }) } +// SetRequestedModel sets the "requested_model" field. +func (u *UsageLogUpsertOne) SetRequestedModel(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetRequestedModel(v) + }) +} + +// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateRequestedModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRequestedModel() + }) +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (u *UsageLogUpsertOne) ClearRequestedModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearRequestedModel() + }) +} + // SetUpstreamModel sets the "upstream_model" field. func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2496,6 +2558,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk { }) } +// SetRequestedModel sets the "requested_model" field. +func (u *UsageLogUpsertBulk) SetRequestedModel(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetRequestedModel(v) + }) +} + +// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateRequestedModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRequestedModel() + }) +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (u *UsageLogUpsertBulk) ClearRequestedModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearRequestedModel() + }) +} + // SetUpstreamModel sets the "upstream_model" field. func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index b7c4632c..516407b9 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate { return _u } +// SetRequestedModel sets the "requested_model" field. +func (_u *UsageLogUpdate) SetRequestedModel(v string) *UsageLogUpdate { + _u.mutation.SetRequestedModel(v) + return _u +} + +// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableRequestedModel(v *string) *UsageLogUpdate { + if v != nil { + _u.SetRequestedModel(*v) + } + return _u +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (_u *UsageLogUpdate) ClearRequestedModel() *UsageLogUpdate { + _u.mutation.ClearRequestedModel() + return _u +} + // SetUpstreamModel sets the "upstream_model" field. func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate { _u.mutation.SetUpstreamModel(v) @@ -765,6 +785,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.RequestedModel(); ok { + if err := usagelog.RequestedModelValidator(v); err != nil { + return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)} + } + } if v, ok := _u.mutation.UpstreamModel(); ok { if err := usagelog.UpstreamModelValidator(v); err != nil { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} @@ -820,6 +845,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.RequestedModel(); ok { + _spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value) + } + if _u.mutation.RequestedModelCleared() { + _spec.ClearField(usagelog.FieldRequestedModel, field.TypeString) + } if value, ok := _u.mutation.UpstreamModel(); ok { _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) } @@ -1208,6 +1239,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne { return _u } +// SetRequestedModel sets the "requested_model" field. +func (_u *UsageLogUpdateOne) SetRequestedModel(v string) *UsageLogUpdateOne { + _u.mutation.SetRequestedModel(v) + return _u +} + +// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableRequestedModel(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetRequestedModel(*v) + } + return _u +} + +// ClearRequestedModel clears the value of the "requested_model" field. +func (_u *UsageLogUpdateOne) ClearRequestedModel() *UsageLogUpdateOne { + _u.mutation.ClearRequestedModel() + return _u +} + // SetUpstreamModel sets the "upstream_model" field. func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne { _u.mutation.SetUpstreamModel(v) @@ -1884,6 +1935,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.RequestedModel(); ok { + if err := usagelog.RequestedModelValidator(v); err != nil { + return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)} + } + } if v, ok := _u.mutation.UpstreamModel(); ok { if err := usagelog.UpstreamModelValidator(v); err != nil { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} @@ -1956,6 +2012,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.RequestedModel(); ok { + _spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value) + } + if _u.mutation.RequestedModelCleared() { + _spec.ClearField(usagelog.FieldRequestedModel, field.TypeString) + } if value, ok := _u.mutation.UpstreamModel(); ok { _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) } diff --git a/backend/go.sum b/backend/go.sum index 270be5f8..f5b7968f 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -94,6 +94,10 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= @@ -195,6 +199,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -230,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -263,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -314,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index c91566c8..c209caf9 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -176,6 +177,7 @@ type UpdateSettingsRequest struct { PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 DefaultConcurrency int `json:"default_concurrency"` @@ -417,6 +419,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { customMenuJSON = string(menuBytes) } + // 自定义端点验证 + const ( + maxCustomEndpoints = 10 + maxEndpointNameLen = 50 + maxEndpointURLLen = 2048 + maxEndpointDescriptionLen = 200 + ) + + customEndpointsJSON := previousSettings.CustomEndpoints + if req.CustomEndpoints != nil { + endpoints := *req.CustomEndpoints + if len(endpoints) > maxCustomEndpoints { + response.BadRequest(c, "Too many custom endpoints (max 10)") + return + } + for _, ep := range endpoints { + if strings.TrimSpace(ep.Name) == "" { + response.BadRequest(c, "Custom endpoint name is required") + return + } + if len(ep.Name) > maxEndpointNameLen { + response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)") + return + } + if strings.TrimSpace(ep.Endpoint) == "" { + response.BadRequest(c, "Custom endpoint URL is required") + return + } + if len(ep.Endpoint) > maxEndpointURLLen { + response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)") + return + } + if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil { + response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL") + return + } + if len(ep.Description) > maxEndpointDescriptionLen { + response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)") + return + } + } + endpointBytes, err := json.Marshal(endpoints) + if err != nil { + response.BadRequest(c, "Failed to serialize custom endpoints") + return + } + customEndpointsJSON = string(endpointBytes) + } + // Ops metrics collector interval validation (seconds). if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -495,6 +546,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionURL: purchaseURL, SoraClientEnabled: req.SoraClientEnabled, CustomMenuItems: customMenuJSON, + CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -592,6 +644,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, SoraClientEnabled: updatedSettings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, DefaultSubscriptions: updatedDefaultSubscriptions, diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d1d867ee..8150aa8e 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -276,11 +276,17 @@ func AccountFromServiceShallow(a *service.Account) *Account { if limit := a.GetQuotaDailyLimit(); limit > 0 { out.QuotaDailyLimit = &limit used := a.GetQuotaDailyUsed() + if a.IsDailyQuotaPeriodExpired() { + used = 0 + } out.QuotaDailyUsed = &used } if limit := a.GetQuotaWeeklyLimit(); limit > 0 { out.QuotaWeeklyLimit = &limit used := a.GetQuotaWeeklyUsed() + if a.IsWeeklyQuotaPeriodExpired() { + used = 0 + } out.QuotaWeeklyUsed = &used } // 固定时间重置配置 @@ -516,14 +522,17 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { // 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。 requestType := l.EffectiveRequestType() stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode) + requestedModel := l.RequestedModel + if requestedModel == "" { + requestedModel = l.Model + } return UsageLog{ ID: l.ID, UserID: l.UserID, APIKeyID: l.APIKeyID, AccountID: l.AccountID, RequestID: l.RequestID, - Model: l.Model, - UpstreamModel: l.UpstreamModel, + Model: requestedModel, ServiceTier: l.ServiceTier, ReasoningEffort: l.ReasoningEffort, InboundEndpoint: l.InboundEndpoint, @@ -580,6 +589,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { } return &AdminUsageLog{ UsageLog: usageLogFromServiceUser(l), + UpstreamModel: l.UpstreamModel, AccountRateMultiplier: l.AccountRateMultiplier, IPAddress: l.IPAddress, Account: AccountSummaryFromService(l.Account), diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go index e4031970..c2635e33 100644 --- a/backend/internal/handler/dto/mappers_usage_test.go +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -1,6 +1,7 @@ package dto import ( + "encoding/json" "testing" "github.com/Wei-Shaw/sub2api/internal/service" @@ -106,6 +107,47 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) { require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12) } +func TestUsageLogFromService_UsesRequestedModelAndKeepsUpstreamAdminOnly(t *testing.T) { + t.Parallel() + + upstreamModel := "claude-sonnet-4-20250514" + log := &service.UsageLog{ + RequestID: "req_4", + Model: upstreamModel, + RequestedModel: "claude-sonnet-4", + UpstreamModel: &upstreamModel, + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.Equal(t, "claude-sonnet-4", userDTO.Model) + require.Equal(t, "claude-sonnet-4", adminDTO.Model) + + userJSON, err := json.Marshal(userDTO) + require.NoError(t, err) + require.NotContains(t, string(userJSON), "upstream_model") + + adminJSON, err := json.Marshal(adminDTO) + require.NoError(t, err) + require.Contains(t, string(adminJSON), `"upstream_model":"claude-sonnet-4-20250514"`) +} + +func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *testing.T) { + t.Parallel() + + log := &service.UsageLog{ + RequestID: "req_legacy", + Model: "claude-3", + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.Equal(t, "claude-3", userDTO.Model) + require.Equal(t, "claude-3", adminDTO.Model) +} + func f64Ptr(value float64) *float64 { return &value } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 0f4f8fdc..7ea34aa0 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -15,6 +15,13 @@ type CustomMenuItem struct { SortOrder int `json:"sort_order"` } +// CustomEndpoint represents an admin-configured API endpoint for quick copy. +type CustomEndpoint struct { + Name string `json:"name"` + Endpoint string `json:"endpoint"` + Description string `json:"description"` +} + // SystemSettings represents the admin settings API response payload. type SystemSettings struct { RegistrationEnabled bool `json:"registration_enabled"` @@ -56,6 +63,7 @@ type SystemSettings struct { PurchaseSubscriptionURL string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` @@ -114,6 +122,7 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` @@ -218,3 +227,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem { } return filtered } + +// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint. +// Returns empty slice on empty/invalid input. +func ParseCustomEndpoints(raw string) []CustomEndpoint { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []CustomEndpoint{} + } + var items []CustomEndpoint + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []CustomEndpoint{} + } + return items +} diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 7b3443be..d4a24e10 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -334,9 +334,6 @@ type UsageLog struct { AccountID int64 `json:"account_id"` RequestID string `json:"request_id"` Model string `json:"model"` - // UpstreamModel is the actual model sent to the upstream provider after mapping. - // Omitted when no mapping was applied (requested model was used as-is). - UpstreamModel *string `json:"upstream_model,omitempty"` // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". ServiceTier *string `json:"service_tier,omitempty"` // ReasoningEffort is the request's reasoning effort level. @@ -396,6 +393,10 @@ type UsageLog struct { type AdminUsageLog struct { UsageLog + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Omitted when no mapping was applied (requested model was used as-is). + UpstreamModel *string `json:"upstream_model,omitempty"` + // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) AccountRateMultiplier *float64 `json:"account_rate_multiplier"` diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 4db5cadd..dd158d8b 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -181,7 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - defaultMappedModel := c.GetString("openai_chat_completions_fallback_model") + defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index ec957feb..b7f18d21 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -37,6 +37,16 @@ type OpenAIGatewayHandler struct { cfg *config.Config } +func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string { + if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" { + return fallbackModel + } + if apiKey == nil || apiKey.Group == nil { + return "" + } + return strings.TrimSpace(apiKey.Group.DefaultMappedModel) +} + // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, @@ -657,9 +667,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - // 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时, - // 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。 - defaultMappedModel := c.GetString("openai_messages_fallback_model") + // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 + // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 + defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index a26b3a0c..7bbf94ec 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -352,6 +352,30 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) { }) } +func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { + t.Run("prefers_explicit_fallback_model", func(t *testing.T) { + apiKey := &service.APIKey{ + Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, + } + require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 ")) + }) + + t.Run("uses_group_default_on_normal_path", func(t *testing.T) { + apiKey := &service.APIKey{ + Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, + } + require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, "")) + }) + + t.Run("returns_empty_without_group_default", func(t *testing.T) { + require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, "")) + require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, "")) + require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{ + Group: &service.Group{}, + }, "")) + }) +} + func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 92061895..2c999cf1 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, SoraClientEnabled: settings.SoraClientEnabled, BackendModeEnabled: settings.BackendModeEnabled, diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index 34f5b60c..095305c2 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -632,8 +632,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - // thinking.type is ignored for effort; default xhigh applies. - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + // thinking.type is ignored for effort; default high applies. + assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "auto", resp.Reasoning.Summary) assert.Contains(t, resp.Include, "reasoning.encrypted_content") assert.NotContains(t, resp.Include, "reasoning.summary") @@ -650,8 +650,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - // thinking.type is ignored for effort; default xhigh applies. - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + // thinking.type is ignored for effort; default high applies. + assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "auto", resp.Reasoning.Summary) assert.NotContains(t, resp.Include, "reasoning.summary") } @@ -666,9 +666,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) - // Default effort applies (high → xhigh) even when thinking is disabled. + // Default effort applies (high → high) even when thinking is disabled. require.NotNil(t, resp.Reasoning) - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "high", resp.Reasoning.Effort) } func TestAnthropicToResponses_NoThinking(t *testing.T) { @@ -680,9 +680,9 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) - // Default effort applies (high → xhigh) when no thinking/output_config is set. + // Default effort applies (high → high) when no thinking/output_config is set. require.NotNil(t, resp.Reasoning) - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "high", resp.Reasoning.Effort) } // --------------------------------------------------------------------------- @@ -690,7 +690,7 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) { // --------------------------------------------------------------------------- func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { - // Default is xhigh, but output_config.effort="low" overrides. low→low after mapping. + // Default is high, but output_config.effort="low" overrides. low→low after mapping. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -708,7 +708,7 @@ func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) { // No thinking field, but output_config.effort="medium" → creates reasoning. - // medium→high after mapping. + // medium→medium after 1:1 mapping. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -719,12 +719,12 @@ func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "medium", resp.Reasoning.Effort) assert.Equal(t, "auto", resp.Reasoning.Summary) } func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { - // output_config.effort="high" → mapped to "xhigh". + // output_config.effort="high" → mapped to "high" (1:1, both sides' default). req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -732,6 +732,22 @@ func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { OutputConfig: &AnthropicOutputConfig{Effort: "high"}, } + resp, err := AnthropicToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestAnthropicToResponses_OutputConfigMax(t *testing.T) { + // output_config.effort="max" → mapped to OpenAI's highest supported level "xhigh". + req := &AnthropicRequest{ + Model: "gpt-5.2", + MaxTokens: 1024, + Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}}, + OutputConfig: &AnthropicOutputConfig{Effort: "max"}, + } + resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) @@ -740,7 +756,7 @@ func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { } func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { - // No output_config → default xhigh regardless of thinking.type. + // No output_config → default high regardless of thinking.type. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -751,11 +767,11 @@ func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "high", resp.Reasoning.Effort) } func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { - // output_config present but effort empty (e.g. only format set) → default xhigh. + // output_config present but effort empty (e.g. only format set) → default high. req := &AnthropicRequest{ Model: "gpt-5.2", MaxTokens: 1024, @@ -766,7 +782,7 @@ func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { resp, err := AnthropicToResponses(req) require.NoError(t, err) require.NotNil(t, resp.Reasoning) - assert.Equal(t, "xhigh", resp.Reasoning.Effort) + assert.Equal(t, "high", resp.Reasoning.Effort) } // --------------------------------------------------------------------------- diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go index fca3cf1f..485262e8 100644 --- a/backend/internal/pkg/apicompat/anthropic_to_responses.go +++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go @@ -46,9 +46,10 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { } // Determine reasoning effort: only output_config.effort controls the - // level; thinking.type is ignored. Default is xhigh when unset. - // Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh. - effort := "high" // default → maps to xhigh + // level; thinking.type is ignored. Default is high when unset (both + // Anthropic and OpenAI default to high). + // Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh. + effort := "high" // default → both sides' default if req.OutputConfig != nil && req.OutputConfig.Effort != "" { effort = req.OutputConfig.Effort } @@ -380,18 +381,19 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string { // mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to // OpenAI Responses API effort levels. // +// Both APIs default to "high". The mapping is 1:1 for shared levels; +// only Anthropic's "max" (Opus 4.6 exclusive) maps to OpenAI's "xhigh" +// (GPT-5.2+ exclusive) as both represent the highest reasoning tier. +// // low → low -// medium → high -// high → xhigh +// medium → medium +// high → high +// max → xhigh func mapAnthropicEffortToResponses(effort string) string { - switch effort { - case "medium": - return "high" - case "high": + if effort == "max" { return "xhigh" - default: - return effort // "low" and any unknown values pass through unchanged } + return effort // low→low, medium→medium, high→high, unknown→passthrough } // convertAnthropicToolsToResponses maps Anthropic tool definitions to diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index 8b819033..f54a4a02 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -181,6 +181,35 @@ func TestChatCompletionsToResponses_ImageURL(t *testing.T) { assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL) } +func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "system", Content: json.RawMessage(`[{"type":"text","text":"You are a careful visual assistant."}]`)}, + {Role: "user", Content: json.RawMessage(`[{"type":"text","text":"Describe this image"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + + var systemParts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &systemParts)) + require.Len(t, systemParts, 1) + assert.Equal(t, "input_text", systemParts[0].Type) + assert.Equal(t, "You are a careful visual assistant.", systemParts[0].Text) + + var userParts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &userParts)) + require.Len(t, userParts, 2) + assert.Equal(t, "input_image", userParts[1].Type) + assert.Equal(t, "data:image/png;base64,abc123", userParts[1].ImageURL) +} + func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { req := &ChatCompletionsRequest{ Model: "gpt-4o", @@ -398,6 +427,45 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) { assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent) } +func TestChatCompletionsToResponses_ToolArrayContent(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Use the tool"`)}, + { + Role: "assistant", + ToolCalls: []ChatToolCall{ + { + ID: "call_1", + Type: "function", + Function: ChatFunctionCall{ + Name: "inspect_image", + Arguments: `{}`, + }, + }, + }, + }, + { + Role: "tool", + ToolCallID: "call_1", + Content: json.RawMessage( + `[{"type":"text","text":"image width: 100"},{"type":"image_url","image_url":{"url":"data:image/png;base64,ignored"}},{"type":"text","text":"; image height: 200"}]`, + ), + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 3) + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "call_1", items[2].CallID) + assert.Equal(t, "image width: 100; image height: 200", items[2].Output) +} + func TestResponsesToChatCompletions_Incomplete(t *testing.T) { resp := &ResponsesResponse{ ID: "resp_inc", diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go index c4a9e773..6cdd012a 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -6,6 +6,11 @@ import ( "strings" ) +type chatMessageContent struct { + Text *string + Parts []ChatContentPart +} + // ChatCompletionsToResponses converts a Chat Completions request into a // Responses API request. The upstream always streams, so Stream is forced to // true. store is always false and reasoning.encrypted_content is always @@ -113,11 +118,11 @@ func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) { // chatSystemToResponses converts a system message. func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { - text, err := parseChatContent(m.Content) + parsed, err := parseChatMessageContent(m.Content) if err != nil { return nil, err } - content, err := json.Marshal(text) + content, err := marshalChatInputContent(parsed) if err != nil { return nil, err } @@ -127,39 +132,11 @@ func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { // chatUserToResponses converts a user message, handling both plain strings and // multi-modal content arrays. func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) { - // Try plain string first. - var s string - if err := json.Unmarshal(m.Content, &s); err == nil { - content, _ := json.Marshal(s) - return []ResponsesInputItem{{Role: "user", Content: content}}, nil - } - - var parts []ChatContentPart - if err := json.Unmarshal(m.Content, &parts); err != nil { + parsed, err := parseChatMessageContent(m.Content) + if err != nil { return nil, fmt.Errorf("parse user content: %w", err) } - - var responseParts []ResponsesContentPart - for _, p := range parts { - switch p.Type { - case "text": - if p.Text != "" { - responseParts = append(responseParts, ResponsesContentPart{ - Type: "input_text", - Text: p.Text, - }) - } - case "image_url": - if p.ImageURL != nil && p.ImageURL.URL != "" { - responseParts = append(responseParts, ResponsesContentPart{ - Type: "input_image", - ImageURL: p.ImageURL.URL, - }) - } - } - } - - content, err := json.Marshal(responseParts) + content, err := marshalChatInputContent(parsed) if err != nil { return nil, err } @@ -312,16 +289,79 @@ func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) { } // parseChatContent returns the string value of a ChatMessage Content field. -// Content must be a JSON string. Returns "" if content is null or empty. +// Content can be a JSON string or an array of typed parts. Array content is +// flattened to text by concatenating text parts and ignoring non-text parts. func parseChatContent(raw json.RawMessage) (string, error) { + parsed, err := parseChatMessageContent(raw) + if err != nil { + return "", err + } + if parsed.Text != nil { + return *parsed.Text, nil + } + return flattenChatContentParts(parsed.Parts), nil +} + +func parseChatMessageContent(raw json.RawMessage) (chatMessageContent, error) { if len(raw) == 0 { - return "", nil + return chatMessageContent{Text: stringPtr("")}, nil } + var s string - if err := json.Unmarshal(raw, &s); err != nil { - return "", fmt.Errorf("parse content as string: %w", err) + if err := json.Unmarshal(raw, &s); err == nil { + return chatMessageContent{Text: &s}, nil } - return s, nil + + var parts []ChatContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + return chatMessageContent{Parts: parts}, nil + } + + return chatMessageContent{}, fmt.Errorf("parse content as string or parts array") +} + +func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error) { + if content.Text != nil { + return json.Marshal(*content.Text) + } + return json.Marshal(convertChatContentPartsToResponses(content.Parts)) +} + +func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart { + var responseParts []ResponsesContentPart + for _, p := range parts { + switch p.Type { + case "text": + if p.Text != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_text", + Text: p.Text, + }) + } + case "image_url": + if p.ImageURL != nil && p.ImageURL.URL != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_image", + ImageURL: p.ImageURL.URL, + }) + } + } + } + return responseParts +} + +func flattenChatContentParts(parts []ChatContentPart) string { + var textParts []string + for _, p := range parts { + if p.Type == "text" && p.Text != "" { + textParts = append(textParts, p.Text) + } + } + return strings.Join(textParts, "") +} + +func stringPtr(s string) *string { + return &s } // convertChatToolsToResponses maps Chat Completions tool definitions and legacy diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index ca454606..e4da825b 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,50 +28,64 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +// usageLogInsertArgTypes must stay in the same order as: +// 1. prepareUsageLogInsert().args +// 2. every INSERT/CTE VALUES column list in this file +// 3. execUsageLogInsertNoResult placeholder positions +// 4. scanUsageLog selected column order (via usageLogSelectColumns) +// +// When adding a usage_logs column, update all of those call sites together. var usageLogInsertArgTypes = [...]string{ - "bigint", - "bigint", - "bigint", - "text", - "text", - "text", - "bigint", - "bigint", - "integer", - "integer", - "integer", - "integer", - "integer", - "integer", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "numeric", - "smallint", - "smallint", - "boolean", - "boolean", - "integer", - "integer", - "text", - "text", - "integer", - "text", - "text", - "text", - "text", - "text", - "text", - "boolean", - "timestamptz", + "bigint", // user_id + "bigint", // api_key_id + "bigint", // account_id + "text", // request_id + "text", // model + "text", // requested_model + "text", // upstream_model + "bigint", // group_id + "bigint", // subscription_id + "integer", // input_tokens + "integer", // output_tokens + "integer", // cache_creation_tokens + "integer", // cache_read_tokens + "integer", // cache_creation_5m_tokens + "integer", // cache_creation_1h_tokens + "numeric", // input_cost + "numeric", // output_cost + "numeric", // cache_creation_cost + "numeric", // cache_read_cost + "numeric", // total_cost + "numeric", // actual_cost + "numeric", // rate_multiplier + "numeric", // account_rate_multiplier + "smallint", // billing_type + "smallint", // request_type + "boolean", // stream + "boolean", // openai_ws_mode + "integer", // duration_ms + "integer", // first_token_ms + "text", // user_agent + "text", // ip_address + "integer", // image_count + "text", // image_size + "text", // media_type + "text", // service_tier + "text", // reasoning_effort + "text", // inbound_endpoint + "text", // upstream_endpoint + "boolean", // cache_ttl_overridden + "timestamptz", // created_at } +const rawUsageLogModelColumn = "model" + +// rawUsageLogModelColumn preserves the exact stored usage_logs.model semantics for direct filters. +// Historical rows may contain upstream/billing model values, while newer rows store requested_model. +// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead. + // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ "hour": "YYYY-MM-DD HH24:00", @@ -88,6 +102,30 @@ func safeDateFormat(granularity string) string { return "YYYY-MM-DD" } +// appendRawUsageLogModelWhereCondition keeps direct model filters on the raw model column for backward +// compatibility with historical rows. Requested/upstream analytics must use +// resolveModelDimensionExpression instead. +func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model string) ([]string, []any) { + if strings.TrimSpace(model) == "" { + return conditions, args + } + conditions = append(conditions, fmt.Sprintf("%s = $%d", rawUsageLogModelColumn, len(args)+1)) + args = append(args, model) + return conditions, args +} + +// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward +// compatibility with historical rows. Requested/upstream analytics must use +// resolveModelDimensionExpression instead. +func appendRawUsageLogModelQueryFilter(query string, args []any, model string) (string, []any) { + if strings.TrimSpace(model) == "" { + return query, args + } + query += fmt.Sprintf(" AND %s = $%d", rawUsageLogModelColumn, len(args)+1) + args = append(args, model) + return query, args +} + type usageLogRepository struct { client *dbent.Client sql sqlExecutor @@ -278,6 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -313,12 +352,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, $6, - $7, $8, - $9, $10, $11, $12, - $13, $14, - $15, $16, $17, $18, $19, $20, - $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 + $1, $2, $3, $4, $5, $6, $7, + $8, $9, + $10, $11, $12, $13, + $14, $15, + $16, $17, $18, $19, $20, $21, + $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -709,6 +748,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -779,6 +819,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -820,6 +861,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -901,6 +943,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -937,7 +980,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*39) + args := make([]any, 0, len(preparedList)*40) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -968,6 +1011,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -1009,6 +1053,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -1058,6 +1103,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared account_id, request_id, model, + requested_model, upstream_model, group_id, subscription_id, @@ -1093,12 +1139,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, $6, - $7, $8, - $9, $10, $11, $12, - $13, $14, - $15, $16, $17, $18, $19, $20, - $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 + $1, $2, $3, $4, $5, $6, $7, + $8, $9, + $10, $11, $12, $13, + $14, $15, + $16, $17, $18, $19, $20, $21, + $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1130,6 +1176,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + requestedModel := strings.TrimSpace(log.RequestedModel) + if requestedModel == "" { + requestedModel = strings.TrimSpace(log.Model) + } upstreamModel := nullString(log.UpstreamModel) var requestIDArg any @@ -1148,6 +1198,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.AccountID, requestIDArg, log.Model, + nullString(&requestedModel), upstreamModel, groupID, subscriptionID, @@ -1702,7 +1753,7 @@ func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, acco // GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 // 性能优化:数据库层聚合计算,避免应用层循环统计 func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { - query := ` + query := fmt.Sprintf(` SELECT COUNT(*) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, @@ -1712,8 +1763,8 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms FROM usage_logs - WHERE model = $1 AND created_at >= $2 AND created_at < $3 - ` + WHERE %s = $1 AND created_at >= $2 AND created_at < $3 + `, rawUsageLogModelColumn) var stats usagestats.UsageStats if err := scanSingleRow( @@ -1837,7 +1888,7 @@ func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco } func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" + query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn) logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) return logs, nil, err } @@ -2532,10 +2583,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) args = append(args, filters.GroupID) } - if filters.Model != "" { - conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) - args = append(args, filters.Model) - } + conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) @@ -2768,10 +2816,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -3126,13 +3171,14 @@ func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayS // resolveModelDimensionExpression maps model source type to a safe SQL expression. func resolveModelDimensionExpression(modelType string) string { + requestedExpr := "COALESCE(NULLIF(TRIM(requested_model), ''), model)" switch usagestats.NormalizeModelSource(modelType) { case usagestats.ModelSourceUpstream: - return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)" + return fmt.Sprintf("COALESCE(NULLIF(TRIM(upstream_model), ''), %s)", requestedExpr) case usagestats.ModelSourceMapping: - return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))" + return fmt.Sprintf("(%s || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), %s))", requestedExpr, requestedExpr) default: - return "model" + return requestedExpr } } @@ -3204,10 +3250,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) args = append(args, filters.GroupID) } - if filters.Model != "" { - conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) - args = append(args, filters.Model) - } + conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) @@ -3336,10 +3379,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -3410,10 +3450,7 @@ func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if model != "" { - query += fmt.Sprintf(" AND model = $%d", len(args)+1) - args = append(args, model) - } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) @@ -3888,6 +3925,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e accountID int64 requestID sql.NullString model string + requestedModel sql.NullString upstreamModel sql.NullString groupID sql.NullInt64 subscriptionID sql.NullInt64 @@ -3931,6 +3969,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &accountID, &requestID, &model, + &requestedModel, &upstreamModel, &groupID, &subscriptionID, @@ -3975,6 +4014,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e APIKeyID: apiKeyID, AccountID: accountID, Model: model, + RequestedModel: coalesceTrimmedString(requestedModel, model), InputTokens: inputTokens, OutputTokens: outputTokens, CacheCreationTokens: cacheCreationTokens, @@ -4181,6 +4221,13 @@ func nullString(v *string) sql.NullString { return sql.NullString{String: *v, Valid: true} } +func coalesceTrimmedString(v sql.NullString, fallback string) string { + if v.Valid && strings.TrimSpace(v.String) != "" { + return v.String + } + return fallback +} + func setToSlice(set map[int64]struct{}) []int64 { out := make([]int64, 0, len(set)) for id := range set { diff --git a/backend/internal/repository/usage_log_repo_breakdown_test.go b/backend/internal/repository/usage_log_repo_breakdown_test.go index 5d908bfd..da62e8dd 100644 --- a/backend/internal/repository/usage_log_repo_breakdown_test.go +++ b/backend/internal/repository/usage_log_repo_breakdown_test.go @@ -34,11 +34,11 @@ func TestResolveModelDimensionExpression(t *testing.T) { modelType string want string }{ - {usagestats.ModelSourceRequested, "model"}, - {usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"}, - {usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"}, - {"", "model"}, - {"invalid", "model"}, + {usagestats.ModelSourceRequested, "COALESCE(NULLIF(TRIM(requested_model), ''), model)"}, + {usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model))"}, + {usagestats.ModelSourceMapping, "(COALESCE(NULLIF(TRIM(requested_model), ''), model) || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model)))"}, + {"", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"}, + {"invalid", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"}, } for _, tc := range tests { diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 76827c31..ebc8929a 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "database/sql/driver" "fmt" "reflect" "testing" @@ -21,20 +22,21 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) log := &service.UsageLog{ - UserID: 1, - APIKeyID: 2, - AccountID: 3, - RequestID: "req-1", - Model: "gpt-5", - InputTokens: 10, - OutputTokens: 20, - TotalCost: 1, - ActualCost: 1, - BillingType: service.BillingTypeBalance, - RequestType: service.RequestTypeWSV2, - Stream: false, - OpenAIWSMode: false, - CreatedAt: createdAt, + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-1", + Model: "gpt-5", + RequestedModel: "gpt-5", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1, + ActualCost: 1, + BillingType: service.BillingTypeBalance, + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + CreatedAt: createdAt, } mock.ExpectQuery("INSERT INTO usage_logs"). @@ -44,6 +46,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.AccountID, log.RequestID, log.Model, + log.RequestedModel, sqlmock.AnyArg(), // upstream_model sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // subscription_id @@ -99,13 +102,14 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC) serviceTier := "priority" log := &service.UsageLog{ - UserID: 1, - APIKeyID: 2, - AccountID: 3, - RequestID: "req-service-tier", - Model: "gpt-5.4", - ServiceTier: &serviceTier, - CreatedAt: createdAt, + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-service-tier", + Model: "gpt-5.4", + RequestedModel: "gpt-5.4", + ServiceTier: &serviceTier, + CreatedAt: createdAt, } mock.ExpectQuery("INSERT INTO usage_logs"). @@ -115,6 +119,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { log.AccountID, log.RequestID, log.Model, + log.RequestedModel, sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), @@ -158,6 +163,75 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { require.NoError(t, mock.ExpectationsWereMet()) } +func TestBuildUsageLogBestEffortInsertQuery_IncludesRequestedModelColumn(t *testing.T) { + prepared := prepareUsageLogInsert(&service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-best-effort-query", + Model: "gpt-5", + RequestedModel: "gpt-5", + CreatedAt: time.Date(2025, 1, 3, 12, 0, 0, 0, time.UTC), + }) + + query, args := buildUsageLogBestEffortInsertQuery([]usageLogInsertPrepared{prepared}) + + require.Contains(t, query, "INSERT INTO usage_logs (") + require.Contains(t, query, "\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,") + require.Contains(t, query, "\n\t\t\trequest_id,\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,") + require.Len(t, args, len(prepared.args)) + require.Equal(t, prepared.args[5], args[5]) +} + +func TestExecUsageLogInsertNoResult_PersistsRequestedModel(t *testing.T) { + db, mock := newSQLMock(t) + prepared := prepareUsageLogInsert(&service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-best-effort-exec", + Model: "gpt-5", + RequestedModel: "gpt-5", + CreatedAt: time.Date(2025, 1, 4, 12, 0, 0, 0, time.UTC), + }) + + mock.ExpectExec("INSERT INTO usage_logs"). + WithArgs(anySliceToDriverValues(prepared.args)...). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := execUsageLogInsertNoResult(context.Background(), db, prepared) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) { + prepared := prepareUsageLogInsert(&service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-arg-count", + Model: "gpt-5", + RequestedModel: "gpt-5", + CreatedAt: time.Date(2025, 1, 5, 12, 0, 0, 0, time.UTC), + }) + + require.Len(t, prepared.args, len(usageLogInsertArgTypes)) +} + +func TestCoalesceTrimmedString(t *testing.T) { + require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback")) + require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback")) + require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback")) +} + +func anySliceToDriverValues(values []any) []driver.Value { + out := make([]driver.Value, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + return out +} + func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { db, mock := newSQLMock(t) repo := &usageLogRepository{sql: db} @@ -354,7 +428,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(20), // api_key_id int64(30), // account_id sql.NullString{Valid: true, String: "req-1"}, - "gpt-5", // model + "gpt-5", // model + sql.NullString{Valid: true, String: "gpt-5"}, // requested_model sql.NullString{}, // upstream_model sql.NullInt64{}, // group_id sql.NullInt64{}, // subscription_id @@ -407,6 +482,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(31), sql.NullString{Valid: true, String: "req-2"}, "gpt-5", + sql.NullString{Valid: true, String: "gpt-5"}, sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, @@ -449,6 +525,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(32), sql.NullString{Valid: true, String: "req-3"}, "gpt-5.4", + sql.NullString{Valid: true, String: "gpt-5.4"}, sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index a6bd50ac..8509c8a9 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -540,7 +540,8 @@ func TestAPIContracts(t *testing.T) { "max_claude_code_version": "", "allow_ungrouped_key_scheduling": false, "backend_mode_enabled": false, - "custom_menu_items": [] + "custom_menu_items": [], + "custom_endpoints": [] } }`, }, diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index b6408f5f..d42c6a11 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1543,6 +1543,24 @@ func isPeriodExpired(periodStart time.Time, dur time.Duration) bool { return time.Since(periodStart) >= dur } +// IsDailyQuotaPeriodExpired 检查日配额周期是否已过期(用于显示层判断是否需要将 used 归零) +func (a *Account) IsDailyQuotaPeriodExpired() bool { + start := a.getExtraTime("quota_daily_start") + if a.GetQuotaDailyResetMode() == "fixed" { + return a.isFixedDailyPeriodExpired(start) + } + return isPeriodExpired(start, 24*time.Hour) +} + +// IsWeeklyQuotaPeriodExpired 检查周配额周期是否已过期(用于显示层判断是否需要将 used 归零) +func (a *Account) IsWeeklyQuotaPeriodExpired() bool { + start := a.getExtraTime("quota_weekly_start") + if a.GetQuotaWeeklyResetMode() == "fixed" { + return a.isFixedWeeklyPeriodExpired(start) + } + return isPeriodExpired(start, 7*24*time.Hour) +} + // IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true) func (a *Account) IsQuotaExceeded() bool { // 总额度 diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 50fa78f2..6ee8280c 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: billingModel, // 使用映射模型用于计费和日志 + Model: originalModel, + UpstreamModel: billingModel, Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -2435,7 +2436,8 @@ handleSuccess: return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: billingModel, + Model: originalModel, + UpstreamModel: billingModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 6e0a7305..f5f9434c 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { result, err := svc.Forward(context.Background(), c, account, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, "claude-sonnet-4-5", result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) } // TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel @@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, "gemini-2.5-flash", result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) } func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { @@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, originalModel, result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry") firstReq := string(upstream.requestBodies[0]) diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 99fea0b0..004511f5 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -222,10 +222,10 @@ func (s *BillingService) initFallbackPricing() { LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, } s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{ - InputPricePerToken: 7.5e-7, - OutputPricePerToken: 4.5e-6, - CacheReadPricePerToken: 7.5e-8, - SupportsCacheBreakdown: false, + InputPricePerToken: 7.5e-7, + OutputPricePerToken: 4.5e-6, + CacheReadPricePerToken: 7.5e-8, + SupportsCacheBreakdown: false, } s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ InputPricePerToken: 2e-7, diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 384d5159..4ae5a469 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -119,6 +119,7 @@ const ( SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src) SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) + SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) // 默认配置 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 4c1f0317..1b2f5f51 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) } +func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + mappedModel := "claude-sonnet-4-20250514" + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_models_split", + Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6}, + Model: "claude-sonnet-4", + UpstreamModel: mappedModel, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model) + require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel) +} + func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e23d24de..72cef2ac 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -482,10 +482,12 @@ type ClaudeUsage struct { // ForwardResult 转发结果 type ForwardResult struct { - RequestID string - Usage ClaudeUsage - Model string - UpstreamModel string // Actual upstream model after mapping (empty = no mapping) + RequestID string + Usage ClaudeUsage + Model string + // UpstreamModel is the actual upstream model after mapping. + // Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings. + UpstreamModel string Stream bool Duration time.Duration FirstTokenMs *int // 首字时间(流式请求) @@ -4197,7 +4199,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } - logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) + logger.LegacyPrintf("service.gateway", "[warn] Account %d: thinking blocks have invalid signature, retrying with filtered blocks", account.ID) // Conservative two-stage fallback: // 1) Disable thinking + thinking->text (preserve content) @@ -4212,7 +4214,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { if retryResp.StatusCode < 400 { - logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID) resp = retryResp break } @@ -6102,13 +6104,9 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { return false } - // Log for debugging - logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg) - // 检测signature相关的错误(更宽松的匹配) // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 if strings.Contains(msg, "signature") { - logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error") return true } @@ -7516,6 +7514,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } var cost *CostBreakdown + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) // 根据请求类型选择计费方式 if result.MediaType == "image" || result.MediaType == "video" { @@ -7531,7 +7530,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if result.MediaType == "image" { cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) } else { - cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) + cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) } } else if result.MediaType == "prompt" { cost = &CostBreakdown{} @@ -7545,7 +7544,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu Price4K: apiKey.Group.ImagePrice4K, } } - cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费 tokens := UsageTokens{ @@ -7557,7 +7556,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -7589,6 +7588,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu AccountID: account.ID, RequestID: requestID, Model: result.Model, + RequestedModel: result.Model, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -7719,6 +7719,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } var cost *CostBreakdown + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) // 根据请求类型选择计费方式 if result.ImageCount > 0 { @@ -7731,7 +7732,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * Price4K: apiKey.Group.ImagePrice4K, } } - cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ @@ -7743,7 +7744,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -7771,6 +7772,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * AccountID: account.ID, RequestID: requestID, Model: result.Model, + RequestedModel: result.Model, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index e65c838d..5b1abc11 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -1028,14 +1028,15 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: req.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: req.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, }, nil } @@ -1241,12 +1242,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, }, nil } setOpsUpstreamError(c, 0, safeErr, "") @@ -1310,12 +1312,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, + RequestID: "", + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, }, nil } // Final attempt: surface the upstream error body (passed through below) instead of a generic retry error. @@ -1350,12 +1353,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) return &ForwardResult{ - RequestID: requestID, - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: nil, + RequestID: requestID, + Usage: ClaudeUsage{}, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, }, nil } @@ -1527,14 +1531,15 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, }, nil } diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index 7560f480..17f7e74e 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "encoding/json" "fmt" "io" @@ -15,6 +16,30 @@ import ( "github.com/stretchr/testify/require" ) +type geminiCompatHTTPUpstreamStub struct { + response *http.Response + err error + calls int + lastReq *http.Request +} + +func (s *geminiCompatHTTPUpstreamStub) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + s.calls++ + s.lastReq = req + if s.err != nil { + return nil, s.err + } + if s.response == nil { + return nil, fmt.Errorf("missing stub response") + } + resp := *s.response + return &resp, nil +} + +func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return s.Do(req, proxyURL, accountID, accountConcurrency) +} + // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { tests := []struct { @@ -170,6 +195,42 @@ func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLo require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") } +func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpstreamModel(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + httpStub := &geminiCompatHTTPUpstreamStub{ + response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"x-request-id": []string{"gemini-req-1"}}, + Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)), + }, + } + svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}} + account := &Account{ + ID: 1, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-key", + "model_mapping": map[string]any{ + "claude-sonnet-4": "claude-sonnet-4-20250514", + }, + }, + } + body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}]}`) + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "claude-sonnet-4", result.Model) + require.Equal(t, "claude-sonnet-4-20250514", result.UpstreamModel) + require.Equal(t, 1, httpStub.calls) + require.NotNil(t, httpStub.lastReq) + require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:") +} + func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { claudeReq := map[string]any{ "model": "claude-haiku-4-5-20251001", diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index a35f9127..5aa4db8a 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -879,6 +879,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad require.NoError(t, err) require.NotNil(t, usageRepo.lastLog) require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel) require.NotNil(t, usageRepo.lastLog.UpstreamModel) require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel) require.NotNil(t, usageRepo.lastLog.ServiceTier) @@ -894,6 +895,40 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad require.Equal(t, 1, userRepo.deductCalls) } +func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_upstream_model_billing_fallback", + Model: "gpt-5.1", + UpstreamModel: "gpt-5.1-codex", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost) + require.Equal(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost) + require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount) +} + func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index cf902c20..4e96cf05 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4110,9 +4110,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } - billingModel := result.Model + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) if result.BillingModel != "" { - billingModel = result.BillingModel + billingModel = strings.TrimSpace(result.BillingModel) } serviceTier := "" if result.ServiceTier != nil { @@ -4140,6 +4140,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec AccountID: account.ID, RequestID: requestID, Model: result.Model, + RequestedModel: result.Model, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 7af3ecae..edbb968b 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -68,3 +68,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) { }) } } + +func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) { + account := &Account{ + Credentials: map[string]any{}, + } + + withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "") + if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" { + t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1") + } + + withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4") + if got := normalizeCodexModel(withDefault); got != "gpt-5.4" { + t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4") + } +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 1d3d8fdf..814ec0bd 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2328,6 +2328,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( RequestID: responseID, Usage: *usage, Model: originalModel, + UpstreamModel: mappedModel, ServiceTier: extractOpenAIServiceTier(reqBody), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), Stream: reqStream, @@ -2945,6 +2946,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( RequestID: responseID, Usage: usage, Model: originalModel, + UpstreamModel: mappedModel, ServiceTier: extractOpenAIServiceTierFromBody(payload), ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), Stream: reqStream, diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 88883180..11c5d5ce 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -88,6 +88,7 @@ func (s *OpsAlertEvaluatorService) Start() { if s.stopCh == nil { s.stopCh = make(chan struct{}) } + s.wg.Add(1) go s.run() }) } @@ -105,7 +106,6 @@ func (s *OpsAlertEvaluatorService) Stop() { } func (s *OpsAlertEvaluatorService) run() { - s.wg.Add(1) defer s.wg.Done() // Start immediately to produce early feedback in ops dashboard. @@ -848,7 +848,9 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc return nil, false } return func() { - _, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result() + releaseCtx, releaseCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer releaseCancel() + _, _ = opsAlertEvaluatorReleaseScript.Run(releaseCtx, s.redisClient, []string{key}, s.instanceID).Result() }, true } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f652839c..44d20491 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -150,6 +150,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyPurchaseSubscriptionURL, SettingKeySoraClientEnabled, SettingKeyCustomMenuItems, + SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, SettingKeyBackendModeEnabled, } @@ -195,6 +196,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], + CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", }, nil @@ -247,6 +249,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems json.RawMessage `json:"custom_menu_items"` + CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` Version string `json:"version,omitempty"` @@ -272,6 +275,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), + CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, Version: s.version, @@ -314,6 +318,18 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { return result } +// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]". +func safeRawJSONArray(raw string) json.RawMessage { + raw = strings.TrimSpace(raw) + if raw == "" { + return json.RawMessage("[]") + } + if json.Valid([]byte(raw)) { + return json.RawMessage(raw) + } + return json.RawMessage("[]") +} + // GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url // and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { @@ -454,6 +470,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems + updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) @@ -740,6 +757,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyPurchaseSubscriptionURL: "", SettingKeySoraClientEnabled: "false", SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultSubscriptions: "[]", @@ -805,6 +823,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], + CustomEndpoints: settings[SettingKeyCustomEndpoints], BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", } diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index cd0bed0b..cf1d5eed 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -43,6 +43,7 @@ type SystemSettings struct { PurchaseSubscriptionURL string SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items + CustomEndpoints string // JSON array of custom endpoints DefaultConcurrency int DefaultBalance float64 @@ -104,6 +105,7 @@ type PublicSettings struct { PurchaseSubscriptionURL string SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items + CustomEndpoints string // JSON array of custom endpoints LinuxDoOAuthEnabled bool BackendModeEnabled bool diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index ab6871bb..e9d325f4 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -148,10 +148,13 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) return nil, errors.New("model is required") } + originalModel := reqModel mappedModel := account.GetMappedModel(reqModel) + var upstreamModel string if mappedModel != "" && mappedModel != reqModel { reqModel = mappedModel + upstreamModel = mappedModel } modelCfg, ok := GetSoraModelConfig(reqModel) @@ -213,13 +216,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) } return &ForwardResult{ - RequestID: "", - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", + RequestID: "", + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", }, nil } @@ -269,13 +273,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun c.JSON(http.StatusOK, resp) } return &ForwardResult{ - RequestID: "", - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", + RequestID: "", + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", }, nil } if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { @@ -419,16 +424,17 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun } return &ForwardResult{ - RequestID: taskID, - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: mediaType, - MediaURL: firstMediaURL(finalURLs), - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: taskID, + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: mediaType, + MediaURL: firstMediaURL(finalURLs), + ImageCount: imageCount, + ImageSize: imageSize, }, nil } diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index 206636ff..2fef600c 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -144,6 +144,11 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { ID: 1, Platform: PlatformSora, Status: StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "prompt-enhance-short-10s": "prompt-enhance-short-15s", + }, + }, } body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) @@ -152,6 +157,7 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { require.NotNil(t, result) require.Equal(t, "prompt", result.MediaType) require.Equal(t, "prompt-enhance-short-10s", result.Model) + require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel) } func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 5a498676..576841fa 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -98,6 +98,9 @@ type UsageLog struct { AccountID int64 RequestID string Model string + // RequestedModel is the client-requested model name recorded for stable user/admin display. + // Empty should be treated as Model for backward compatibility with historical rows. + RequestedModel string // UpstreamModel is the actual model sent to the upstream provider after mapping. // Nil means no mapping was applied (requested model was used as-is). UpstreamModel *string diff --git a/backend/internal/service/usage_log_helpers.go b/backend/internal/service/usage_log_helpers.go index 2ab51849..57c51540 100644 --- a/backend/internal/service/usage_log_helpers.go +++ b/backend/internal/service/usage_log_helpers.go @@ -19,3 +19,10 @@ func optionalNonEqualStringPtr(value, compare string) *string { } return &value } + +func forwardResultBillingModel(requestedModel, upstreamModel string) string { + if trimmedUpstream := strings.TrimSpace(upstreamModel); trimmedUpstream != "" { + return trimmedUpstream + } + return strings.TrimSpace(requestedModel) +} diff --git a/backend/migrations/077_add_usage_log_requested_model.sql b/backend/migrations/077_add_usage_log_requested_model.sql new file mode 100644 index 00000000..4b87df86 --- /dev/null +++ b/backend/migrations/077_add_usage_log_requested_model.sql @@ -0,0 +1,3 @@ +-- Add requested_model field to usage_logs for normalized request/upstream model tracking. +-- NULL means historical rows written before requested_model dual-write was introduced. +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100); diff --git a/backend/migrations/078_add_usage_log_requested_model_index_notx.sql b/backend/migrations/078_add_usage_log_requested_model_index_notx.sql new file mode 100644 index 00000000..c3412562 --- /dev/null +++ b/backend/migrations/078_add_usage_log_requested_model_index_notx.sql @@ -0,0 +1,3 @@ +-- Support requested_model / upstream_model aggregations with time-range filters. +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_requested_model_upstream_model +ON usage_logs (created_at, requested_model, upstream_model); diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 0519d2fc..83258bcc 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -4,7 +4,7 @@ */ import { apiClient } from '../client' -import type { CustomMenuItem } from '@/types' +import type { CustomMenuItem, CustomEndpoint } from '@/types' export interface DefaultSubscriptionSetting { group_id: number @@ -43,6 +43,7 @@ export interface SystemSettings { sora_client_enabled: boolean backend_mode_enabled: boolean custom_menu_items: CustomMenuItem[] + custom_endpoints: CustomEndpoint[] // SMTP settings smtp_host: string smtp_port: number @@ -112,6 +113,7 @@ export interface UpdateSettingsRequest { sora_client_enabled?: boolean backend_mode_enabled?: boolean custom_menu_items?: CustomMenuItem[] + custom_endpoints?: CustomEndpoint[] smtp_host?: string smtp_port?: number smtp_username?: string diff --git a/frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts b/frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts index e38bb4f7..9309c88b 100644 --- a/frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts +++ b/frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts @@ -39,6 +39,7 @@ const DataTableStub = { template: `
+
@@ -108,4 +109,42 @@ describe('admin UsageTable tooltip', () => { expect(text).toContain('$30.0000 / 1M tokens') expect(text).toContain('$0.069568') }) + + it('shows requested and upstream models separately for admin rows', () => { + const row = { + request_id: 'req-admin-model-1', + model: 'claude-sonnet-4', + upstream_model: 'claude-sonnet-4-20250514', + actual_cost: 0, + total_cost: 0, + account_rate_multiplier: 1, + rate_multiplier: 1, + input_cost: 0, + output_cost: 0, + cache_creation_cost: 0, + cache_read_cost: 0, + input_tokens: 0, + output_tokens: 0, + } + + const wrapper = mount(UsageTable, { + props: { + data: [row], + loading: false, + columns: [], + }, + global: { + stubs: { + DataTable: DataTableStub, + EmptyState: true, + Icon: true, + Teleport: true, + }, + }, + }) + + const text = wrapper.text() + expect(text).toContain('claude-sonnet-4') + expect(text).toContain('claude-sonnet-4-20250514') + }) }) diff --git a/frontend/src/components/keys/EndpointPopover.vue b/frontend/src/components/keys/EndpointPopover.vue new file mode 100644 index 00000000..49db50b0 --- /dev/null +++ b/frontend/src/components/keys/EndpointPopover.vue @@ -0,0 +1,141 @@ + + + diff --git a/frontend/src/components/keys/__tests__/EndpointPopover.spec.ts b/frontend/src/components/keys/__tests__/EndpointPopover.spec.ts new file mode 100644 index 00000000..4d753da2 --- /dev/null +++ b/frontend/src/components/keys/__tests__/EndpointPopover.spec.ts @@ -0,0 +1,69 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +const copyToClipboard = vi.fn().mockResolvedValue(true) + +const messages: Record = { + 'keys.endpoints.title': 'API 端点', + 'keys.endpoints.default': '默认', + 'keys.endpoints.copied': '已复制', + 'keys.endpoints.copiedHint': '已复制到剪贴板', + 'keys.endpoints.clickToCopy': '点击可复制此端点', + 'keys.endpoints.speedTest': '测速', +} + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => messages[key] ?? key, + }), +})) + +vi.mock('@/composables/useClipboard', () => ({ + useClipboard: () => ({ + copyToClipboard, + }), +})) + +import EndpointPopover from '../EndpointPopover.vue' + +describe('EndpointPopover', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('将说明提示渲染到 URL 上方而不是旧的 title 图标上', () => { + const wrapper = mount(EndpointPopover, { + props: { + apiBaseUrl: 'https://default.example.com/v1', + customEndpoints: [ + { + name: '备用线路', + endpoint: 'https://backup.example.com/v1', + description: '自定义说明', + }, + ], + }, + }) + + expect(wrapper.text()).toContain('自定义说明') + expect(wrapper.text()).toContain('点击可复制此端点') + expect(wrapper.find('[role="button"]').attributes('title')).toBeUndefined() + expect(wrapper.find('[title="自定义说明"]').exists()).toBe(false) + }) + + it('点击 URL 后会复制并切换为已复制提示', async () => { + const wrapper = mount(EndpointPopover, { + props: { + apiBaseUrl: 'https://default.example.com/v1', + customEndpoints: [], + }, + }) + + await wrapper.find('[role="button"]').trigger('click') + await flushPromises() + + expect(copyToClipboard).toHaveBeenCalledWith('https://default.example.com/v1', '已复制') + expect(wrapper.text()).toContain('已复制到剪贴板') + expect(wrapper.find('button[aria-label="已复制到剪贴板"]').exists()).toBe(true) + }) +}) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index e5a370c8..a2f69e2c 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -533,6 +533,14 @@ export default { title: 'API Keys', description: 'Manage your API keys and access tokens', searchPlaceholder: 'Search name or key...', + endpoints: { + title: 'API Endpoints', + default: 'Default', + copied: 'Copied', + copiedHint: 'Copied to clipboard', + clickToCopy: 'Click to copy this endpoint', + speedTest: 'Speed Test', + }, allGroups: 'All Groups', allStatus: 'All Status', createKey: 'Create API Key', @@ -4162,6 +4170,18 @@ export default { apiBaseUrlPlaceholder: 'https://api.example.com', apiBaseUrlHint: 'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.', + customEndpoints: { + title: 'Custom Endpoints', + description: 'Add additional API endpoint URLs for users to quickly copy on the API Keys page', + itemLabel: 'Endpoint #{n}', + name: 'Name', + namePlaceholder: 'e.g., OpenAI Compatible', + endpointUrl: 'Endpoint URL', + endpointUrlPlaceholder: 'https://api2.example.com', + descriptionLabel: 'Description', + descriptionPlaceholder: 'e.g., Supports OpenAI format requests', + add: 'Add Endpoint', + }, contactInfo: 'Contact Info', contactInfoPlaceholder: 'e.g., QQ: 123456789', contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index ac6632be..2eef299c 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -533,6 +533,14 @@ export default { title: 'API 密钥', description: '管理您的 API 密钥和访问令牌', searchPlaceholder: '搜索名称或Key...', + endpoints: { + title: 'API 端点', + default: '默认', + copied: '已复制', + copiedHint: '已复制到剪贴板', + clickToCopy: '点击可复制此端点', + speedTest: '测速', + }, allGroups: '全部分组', allStatus: '全部状态', createKey: '创建密钥', @@ -4324,6 +4332,18 @@ export default { apiBaseUrl: 'API 端点地址', apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址', apiBaseUrlPlaceholder: 'https://api.example.com', + customEndpoints: { + title: '自定义端点', + description: '添加额外的 API 端点地址,用户可在「API Keys」页面快速复制', + itemLabel: '端点 #{n}', + name: '名称', + namePlaceholder: '如:OpenAI Compatible', + endpointUrl: '端点地址', + endpointUrlPlaceholder: 'https://api2.example.com', + descriptionLabel: '介绍', + descriptionPlaceholder: '如:支持 OpenAI 格式请求', + add: '添加端点', + }, contactInfo: '客服联系方式', contactInfoPlaceholder: '例如:QQ: 123456789', contactInfoHint: '填写客服联系方式,将展示在兑换页面、个人资料等位置', diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index dea920c0..c080c2af 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -330,6 +330,7 @@ export const useAppStore = defineStore('app', () => { purchase_subscription_enabled: false, purchase_subscription_url: '', custom_menu_items: [], + custom_endpoints: [], linuxdo_oauth_enabled: false, sora_client_enabled: false, backend_mode_enabled: false, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 056efae2..2656a28d 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -84,6 +84,12 @@ export interface CustomMenuItem { sort_order: number } +export interface CustomEndpoint { + name: string + endpoint: string + description: string +} + export interface PublicSettings { registration_enabled: boolean email_verify_enabled: boolean @@ -104,6 +110,7 @@ export interface PublicSettings { purchase_subscription_enabled: boolean purchase_subscription_url: string custom_menu_items: CustomMenuItem[] + custom_endpoints: CustomEndpoint[] linuxdo_oauth_enabled: boolean sora_client_enabled: boolean backend_mode_enabled: boolean @@ -978,7 +985,6 @@ export interface UsageLog { account_id: number | null request_id: string model: string - upstream_model?: string | null service_tier?: string | null reasoning_effort?: string | null inbound_endpoint?: string | null @@ -1033,6 +1039,8 @@ export interface UsageLogAccountSummary { } export interface AdminUsageLog extends UsageLog { + upstream_model?: string | null + // 账号计费倍率(仅管理员可见) account_rate_multiplier?: number | null diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 99cd247e..dbaa9c37 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -7,7 +7,7 @@ -
+
+ +
+ +

+ {{ t('admin.settings.site.customEndpoints.description') }} +

+ +
+
+
+ + {{ t('admin.settings.site.customEndpoints.itemLabel', { n: index + 1 }) }} + + +
+
+
+ + +
+
+ + +
+
+ + +
+
+
+
+ + +
+