Merge branch 'main' of github.com:InCerryGit/sub2api

# Conflicts:
#	backend/internal/service/billing_service.go
This commit is contained in:
InCerry
2026-03-24 15:08:55 +08:00
59 changed files with 1740 additions and 328 deletions

View File

@@ -716,6 +716,7 @@ var (
{Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "request_id", Type: field.TypeString, Size: 64}, {Name: "request_id", Type: field.TypeString, Size: 64},
{Name: "model", Type: field.TypeString, Size: 100}, {Name: "model", Type: field.TypeString, Size: 100},
{Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0},
@@ -756,31 +757,31 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "usage_logs_api_keys_usage_logs", Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{APIKeysColumns[0]}, RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_accounts_usage_logs", Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{AccountsColumns[0]}, RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_groups_usage_logs", Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "usage_logs_users_usage_logs", Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[32]}, Columns: []*schema.Column{UsageLogsColumns[33]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_user_subscriptions_usage_logs", Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[33]}, Columns: []*schema.Column{UsageLogsColumns[34]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
@@ -789,38 +790,43 @@ var (
{ {
Name: "usagelog_user_id", Name: "usagelog_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]}, Columns: []*schema.Column{UsageLogsColumns[33]},
}, },
{ {
Name: "usagelog_api_key_id", Name: "usagelog_api_key_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
}, },
{ {
Name: "usagelog_account_id", Name: "usagelog_account_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
}, },
{ {
Name: "usagelog_group_id", Name: "usagelog_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[32]},
}, },
{ {
Name: "usagelog_subscription_id", Name: "usagelog_subscription_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33]}, Columns: []*schema.Column{UsageLogsColumns[34]},
}, },
{ {
Name: "usagelog_created_at", Name: "usagelog_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_model", Name: "usagelog_model",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[2]}, Columns: []*schema.Column{UsageLogsColumns[2]},
}, },
{
Name: "usagelog_requested_model",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[3]},
},
{ {
Name: "usagelog_request_id", Name: "usagelog_request_id",
Unique: false, Unique: false,
@@ -829,17 +835,17 @@ var (
{ {
Name: "usagelog_user_id_created_at", Name: "usagelog_user_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_api_key_id_created_at", Name: "usagelog_api_key_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_group_id_created_at", Name: "usagelog_group_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]},
}, },
}, },
} }

View File

@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
id *int64 id *int64
request_id *string request_id *string
model *string model *string
requested_model *string
upstream_model *string upstream_model *string
input_tokens *int input_tokens *int
addinput_tokens *int addinput_tokens *int
@@ -18577,6 +18578,55 @@ func (m *UsageLogMutation) ResetModel() {
m.model = nil m.model = nil
} }
// SetRequestedModel sets the "requested_model" field.
func (m *UsageLogMutation) SetRequestedModel(s string) {
m.requested_model = &s
}
// RequestedModel returns the value of the "requested_model" field in the mutation.
func (m *UsageLogMutation) RequestedModel() (r string, exists bool) {
v := m.requested_model
if v == nil {
return
}
return *v, true
}
// OldRequestedModel returns the old "requested_model" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldRequestedModel(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldRequestedModel is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldRequestedModel requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldRequestedModel: %w", err)
}
return oldValue.RequestedModel, nil
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (m *UsageLogMutation) ClearRequestedModel() {
m.requested_model = nil
m.clearedFields[usagelog.FieldRequestedModel] = struct{}{}
}
// RequestedModelCleared returns if the "requested_model" field was cleared in this mutation.
func (m *UsageLogMutation) RequestedModelCleared() bool {
_, ok := m.clearedFields[usagelog.FieldRequestedModel]
return ok
}
// ResetRequestedModel resets all changes to the "requested_model" field.
func (m *UsageLogMutation) ResetRequestedModel() {
m.requested_model = nil
delete(m.clearedFields, usagelog.FieldRequestedModel)
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (m *UsageLogMutation) SetUpstreamModel(s string) { func (m *UsageLogMutation) SetUpstreamModel(s string) {
m.upstream_model = &s m.upstream_model = &s
@@ -20247,7 +20297,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UsageLogMutation) Fields() []string { func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 33) fields := make([]string, 0, 34)
if m.user != nil { if m.user != nil {
fields = append(fields, usagelog.FieldUserID) fields = append(fields, usagelog.FieldUserID)
} }
@@ -20263,6 +20313,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.model != nil { if m.model != nil {
fields = append(fields, usagelog.FieldModel) fields = append(fields, usagelog.FieldModel)
} }
if m.requested_model != nil {
fields = append(fields, usagelog.FieldRequestedModel)
}
if m.upstream_model != nil { if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel) fields = append(fields, usagelog.FieldUpstreamModel)
} }
@@ -20365,6 +20418,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestID() return m.RequestID()
case usagelog.FieldModel: case usagelog.FieldModel:
return m.Model() return m.Model()
case usagelog.FieldRequestedModel:
return m.RequestedModel()
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
return m.UpstreamModel() return m.UpstreamModel()
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
@@ -20440,6 +20495,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestID(ctx) return m.OldRequestID(ctx)
case usagelog.FieldModel: case usagelog.FieldModel:
return m.OldModel(ctx) return m.OldModel(ctx)
case usagelog.FieldRequestedModel:
return m.OldRequestedModel(ctx)
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx) return m.OldUpstreamModel(ctx)
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
@@ -20540,6 +20597,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
} }
m.SetModel(v) m.SetModel(v)
return nil return nil
case usagelog.FieldRequestedModel:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetRequestedModel(v)
return nil
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
v, ok := value.(string) v, ok := value.(string)
if !ok { if !ok {
@@ -20985,6 +21049,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
// mutation. // mutation.
func (m *UsageLogMutation) ClearedFields() []string { func (m *UsageLogMutation) ClearedFields() []string {
var fields []string var fields []string
if m.FieldCleared(usagelog.FieldRequestedModel) {
fields = append(fields, usagelog.FieldRequestedModel)
}
if m.FieldCleared(usagelog.FieldUpstreamModel) { if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel) fields = append(fields, usagelog.FieldUpstreamModel)
} }
@@ -21029,6 +21096,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema. // error if the field is not defined in the schema.
func (m *UsageLogMutation) ClearField(name string) error { func (m *UsageLogMutation) ClearField(name string) error {
switch name { switch name {
case usagelog.FieldRequestedModel:
m.ClearRequestedModel()
return nil
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel() m.ClearUpstreamModel()
return nil return nil
@@ -21082,6 +21152,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldModel: case usagelog.FieldModel:
m.ResetModel() m.ResetModel()
return nil return nil
case usagelog.FieldRequestedModel:
m.ResetRequestedModel()
return nil
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel() m.ResetUpstreamModel()
return nil return nil

View File

@@ -821,96 +821,100 @@ func init() {
return nil return nil
} }
}() }()
// usagelogDescRequestedModel is the schema descriptor for requested_model field.
usagelogDescRequestedModel := usagelogFields[5].Descriptor()
// usagelog.RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save.
usagelog.RequestedModelValidator = usagelogDescRequestedModel.Validators[0].(func(string) error)
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field. // usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
usagelogDescUpstreamModel := usagelogFields[5].Descriptor() usagelogDescUpstreamModel := usagelogFields[6].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field. // usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[8].Descriptor() usagelogDescInputTokens := usagelogFields[9].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field. // usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[9].Descriptor() usagelogDescOutputTokens := usagelogFields[10].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor() usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor() usagelogDescCacheReadTokens := usagelogFields[12].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor() usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor() usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field. // usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[14].Descriptor() usagelogDescInputCost := usagelogFields[15].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field. // usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field. // usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[15].Descriptor() usagelogDescOutputCost := usagelogFields[16].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor() usagelogDescCacheCreationCost := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[17].Descriptor() usagelogDescCacheReadCost := usagelogFields[18].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field. // usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[18].Descriptor() usagelogDescTotalCost := usagelogFields[19].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field. // usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[19].Descriptor() usagelogDescActualCost := usagelogFields[20].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[20].Descriptor() usagelogDescRateMultiplier := usagelogFields[21].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field. // usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[22].Descriptor() usagelogDescBillingType := usagelogFields[23].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field. // usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field. // usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[23].Descriptor() usagelogDescStream := usagelogFields[24].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field. // usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool) usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field. // usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[26].Descriptor() usagelogDescUserAgent := usagelogFields[27].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field. // usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[27].Descriptor() usagelogDescIPAddress := usagelogFields[28].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field. // usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[28].Descriptor() usagelogDescImageCount := usagelogFields[29].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field. // usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field. // usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[29].Descriptor() usagelogDescImageSize := usagelogFields[30].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field. // usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[30].Descriptor() usagelogDescMediaType := usagelogFields[31].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor() usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field. // usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[32].Descriptor() usagelogDescCreatedAt := usagelogFields[33].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin() userMixin := schema.User{}.Mixin()

View File

@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
field.String("model"). field.String("model").
MaxLen(100). MaxLen(100).
NotEmpty(), NotEmpty(),
// RequestedModel stores the client-requested model name for stable display and analytics.
// NULL means historical rows written before requested_model dual-write was introduced.
field.String("requested_model").
MaxLen(100).
Optional().
Nillable(),
// UpstreamModel stores the actual upstream model name when model mapping // UpstreamModel stores the actual upstream model name when model mapping
// is applied. NULL means no mapping — the requested model was used as-is. // is applied. NULL means no mapping — the requested model was used as-is.
field.String("upstream_model"). field.String("upstream_model").
@@ -181,6 +187,7 @@ func (UsageLog) Indexes() []ent.Index {
index.Fields("subscription_id"), index.Fields("subscription_id"),
index.Fields("created_at"), index.Fields("created_at"),
index.Fields("model"), index.Fields("model"),
index.Fields("requested_model"),
index.Fields("request_id"), index.Fields("request_id"),
// 复合索引用于时间范围查询 // 复合索引用于时间范围查询
index.Fields("user_id", "created_at"), index.Fields("user_id", "created_at"),

View File

@@ -32,6 +32,8 @@ type UsageLog struct {
RequestID string `json:"request_id,omitempty"` RequestID string `json:"request_id,omitempty"`
// Model holds the value of the "model" field. // Model holds the value of the "model" field.
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
// RequestedModel holds the value of the "requested_model" field.
RequestedModel *string `json:"requested_model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field. // UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"` UpstreamModel *string `json:"upstream_model,omitempty"`
// GroupID holds the value of the "group_id" field. // GroupID holds the value of the "group_id" field.
@@ -177,7 +179,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@@ -232,6 +234,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Model = value.String _m.Model = value.String
} }
case usagelog.FieldRequestedModel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field requested_model", values[i])
} else if value.Valid {
_m.RequestedModel = new(string)
*_m.RequestedModel = value.String
}
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
if value, ok := values[i].(*sql.NullString); !ok { if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field upstream_model", values[i]) return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
@@ -486,6 +495,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("model=") builder.WriteString("model=")
builder.WriteString(_m.Model) builder.WriteString(_m.Model)
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.RequestedModel; v != nil {
builder.WriteString("requested_model=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.UpstreamModel; v != nil { if v := _m.UpstreamModel; v != nil {
builder.WriteString("upstream_model=") builder.WriteString("upstream_model=")
builder.WriteString(*v) builder.WriteString(*v)

View File

@@ -24,6 +24,8 @@ const (
FieldRequestID = "request_id" FieldRequestID = "request_id"
// FieldModel holds the string denoting the model field in the database. // FieldModel holds the string denoting the model field in the database.
FieldModel = "model" FieldModel = "model"
// FieldRequestedModel holds the string denoting the requested_model field in the database.
FieldRequestedModel = "requested_model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database. // FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model" FieldUpstreamModel = "upstream_model"
// FieldGroupID holds the string denoting the group_id field in the database. // FieldGroupID holds the string denoting the group_id field in the database.
@@ -137,6 +139,7 @@ var Columns = []string{
FieldAccountID, FieldAccountID,
FieldRequestID, FieldRequestID,
FieldModel, FieldModel,
FieldRequestedModel,
FieldUpstreamModel, FieldUpstreamModel,
FieldGroupID, FieldGroupID,
FieldSubscriptionID, FieldSubscriptionID,
@@ -182,6 +185,8 @@ var (
RequestIDValidator func(string) error RequestIDValidator func(string) error
// ModelValidator is a validator for the "model" field. It is called by the builders before save. // ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error ModelValidator func(string) error
// RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save.
RequestedModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error UpstreamModelValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field. // DefaultInputTokens holds the default value on creation for the "input_tokens" field.
@@ -263,6 +268,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModel, opts...).ToFunc() return sql.OrderByField(FieldModel, opts...).ToFunc()
} }
// ByRequestedModel orders the results by the requested_model field.
func ByRequestedModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRequestedModel, opts...).ToFunc()
}
// ByUpstreamModel orders the results by the upstream_model field. // ByUpstreamModel orders the results by the upstream_model field.
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()

View File

@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
} }
// RequestedModel applies equality check predicate on the "requested_model" field. It's identical to RequestedModelEQ.
func RequestedModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v))
}
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ. // UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
func UpstreamModel(v string) predicate.UsageLog { func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
@@ -410,6 +415,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
} }
// RequestedModelEQ applies the EQ predicate on the "requested_model" field.
func RequestedModelEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v))
}
// RequestedModelNEQ applies the NEQ predicate on the "requested_model" field.
func RequestedModelNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldRequestedModel, v))
}
// RequestedModelIn applies the In predicate on the "requested_model" field.
func RequestedModelIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldRequestedModel, vs...))
}
// RequestedModelNotIn applies the NotIn predicate on the "requested_model" field.
func RequestedModelNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldRequestedModel, vs...))
}
// RequestedModelGT applies the GT predicate on the "requested_model" field.
func RequestedModelGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldRequestedModel, v))
}
// RequestedModelGTE applies the GTE predicate on the "requested_model" field.
func RequestedModelGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldRequestedModel, v))
}
// RequestedModelLT applies the LT predicate on the "requested_model" field.
func RequestedModelLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldRequestedModel, v))
}
// RequestedModelLTE applies the LTE predicate on the "requested_model" field.
func RequestedModelLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldRequestedModel, v))
}
// RequestedModelContains applies the Contains predicate on the "requested_model" field.
func RequestedModelContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldRequestedModel, v))
}
// RequestedModelHasPrefix applies the HasPrefix predicate on the "requested_model" field.
func RequestedModelHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldRequestedModel, v))
}
// RequestedModelHasSuffix applies the HasSuffix predicate on the "requested_model" field.
func RequestedModelHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldRequestedModel, v))
}
// RequestedModelIsNil applies the IsNil predicate on the "requested_model" field.
func RequestedModelIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldRequestedModel))
}
// RequestedModelNotNil applies the NotNil predicate on the "requested_model" field.
func RequestedModelNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldRequestedModel))
}
// RequestedModelEqualFold applies the EqualFold predicate on the "requested_model" field.
func RequestedModelEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldRequestedModel, v))
}
// RequestedModelContainsFold applies the ContainsFold predicate on the "requested_model" field.
func RequestedModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldRequestedModel, v))
}
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field. // UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
func UpstreamModelEQ(v string) predicate.UsageLog { func UpstreamModelEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))

View File

@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
return _c return _c
} }
// SetRequestedModel sets the "requested_model" field.
func (_c *UsageLogCreate) SetRequestedModel(v string) *UsageLogCreate {
_c.mutation.SetRequestedModel(v)
return _c
}
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableRequestedModel(v *string) *UsageLogCreate {
if v != nil {
_c.SetRequestedModel(*v)
}
return _c
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate { func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
_c.mutation.SetUpstreamModel(v) _c.mutation.SetUpstreamModel(v)
@@ -610,6 +624,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _c.mutation.RequestedModel(); ok {
if err := usagelog.RequestedModelValidator(v); err != nil {
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
}
}
if v, ok := _c.mutation.UpstreamModel(); ok { if v, ok := _c.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil { if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
@@ -733,6 +752,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
_node.Model = value _node.Model = value
} }
if value, ok := _c.mutation.RequestedModel(); ok {
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
_node.RequestedModel = &value
}
if value, ok := _c.mutation.UpstreamModel(); ok { if value, ok := _c.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value _node.UpstreamModel = &value
@@ -1034,6 +1057,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
return u return u
} }
// SetRequestedModel sets the "requested_model" field.
func (u *UsageLogUpsert) SetRequestedModel(v string) *UsageLogUpsert {
u.Set(usagelog.FieldRequestedModel, v)
return u
}
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateRequestedModel() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldRequestedModel)
return u
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (u *UsageLogUpsert) ClearRequestedModel() *UsageLogUpsert {
u.SetNull(usagelog.FieldRequestedModel)
return u
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert { func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
u.Set(usagelog.FieldUpstreamModel, v) u.Set(usagelog.FieldUpstreamModel, v)
@@ -1641,6 +1682,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
}) })
} }
// SetRequestedModel sets the "requested_model" field.
func (u *UsageLogUpsertOne) SetRequestedModel(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetRequestedModel(v)
})
}
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateRequestedModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateRequestedModel()
})
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (u *UsageLogUpsertOne) ClearRequestedModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearRequestedModel()
})
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne { func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
@@ -2496,6 +2558,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
}) })
} }
// SetRequestedModel sets the "requested_model" field.
func (u *UsageLogUpsertBulk) SetRequestedModel(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetRequestedModel(v)
})
}
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateRequestedModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateRequestedModel()
})
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (u *UsageLogUpsertBulk) ClearRequestedModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearRequestedModel()
})
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk { func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {

View File

@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
return _u return _u
} }
// SetRequestedModel sets the "requested_model" field.
func (_u *UsageLogUpdate) SetRequestedModel(v string) *UsageLogUpdate {
_u.mutation.SetRequestedModel(v)
return _u
}
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableRequestedModel(v *string) *UsageLogUpdate {
if v != nil {
_u.SetRequestedModel(*v)
}
return _u
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (_u *UsageLogUpdate) ClearRequestedModel() *UsageLogUpdate {
_u.mutation.ClearRequestedModel()
return _u
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate { func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
_u.mutation.SetUpstreamModel(v) _u.mutation.SetUpstreamModel(v)
@@ -765,6 +785,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _u.mutation.RequestedModel(); ok {
if err := usagelog.RequestedModelValidator(v); err != nil {
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok { if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil { if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
@@ -820,6 +845,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Model(); ok { if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
} }
if value, ok := _u.mutation.RequestedModel(); ok {
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
}
if _u.mutation.RequestedModelCleared() {
_spec.ClearField(usagelog.FieldRequestedModel, field.TypeString)
}
if value, ok := _u.mutation.UpstreamModel(); ok { if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
} }
@@ -1208,6 +1239,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
return _u return _u
} }
// SetRequestedModel sets the "requested_model" field.
func (_u *UsageLogUpdateOne) SetRequestedModel(v string) *UsageLogUpdateOne {
_u.mutation.SetRequestedModel(v)
return _u
}
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableRequestedModel(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetRequestedModel(*v)
}
return _u
}
// ClearRequestedModel clears the value of the "requested_model" field.
func (_u *UsageLogUpdateOne) ClearRequestedModel() *UsageLogUpdateOne {
_u.mutation.ClearRequestedModel()
return _u
}
// SetUpstreamModel sets the "upstream_model" field. // SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
_u.mutation.SetUpstreamModel(v) _u.mutation.SetUpstreamModel(v)
@@ -1884,6 +1935,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _u.mutation.RequestedModel(); ok {
if err := usagelog.RequestedModelValidator(v); err != nil {
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok { if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil { if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
@@ -1956,6 +2012,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.Model(); ok { if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
} }
if value, ok := _u.mutation.RequestedModel(); ok {
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
}
if _u.mutation.RequestedModelCleared() {
_spec.ClearField(usagelog.FieldRequestedModel, field.TypeString)
}
if value, ok := _u.mutation.UpstreamModel(); ok { if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
} }

View File

@@ -94,6 +94,10 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
@@ -195,6 +199,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -230,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -263,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -314,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
@@ -176,6 +177,7 @@ type UpdateSettingsRequest struct {
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
@@ -417,6 +419,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
customMenuJSON = string(menuBytes) customMenuJSON = string(menuBytes)
} }
// 自定义端点验证
const (
maxCustomEndpoints = 10
maxEndpointNameLen = 50
maxEndpointURLLen = 2048
maxEndpointDescriptionLen = 200
)
customEndpointsJSON := previousSettings.CustomEndpoints
if req.CustomEndpoints != nil {
endpoints := *req.CustomEndpoints
if len(endpoints) > maxCustomEndpoints {
response.BadRequest(c, "Too many custom endpoints (max 10)")
return
}
for _, ep := range endpoints {
if strings.TrimSpace(ep.Name) == "" {
response.BadRequest(c, "Custom endpoint name is required")
return
}
if len(ep.Name) > maxEndpointNameLen {
response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)")
return
}
if strings.TrimSpace(ep.Endpoint) == "" {
response.BadRequest(c, "Custom endpoint URL is required")
return
}
if len(ep.Endpoint) > maxEndpointURLLen {
response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil {
response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL")
return
}
if len(ep.Description) > maxEndpointDescriptionLen {
response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)")
return
}
}
endpointBytes, err := json.Marshal(endpoints)
if err != nil {
response.BadRequest(c, "Failed to serialize custom endpoints")
return
}
customEndpointsJSON = string(endpointBytes)
}
// Ops metrics collector interval validation (seconds). // Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil { if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds v := *req.OpsMetricsIntervalSeconds
@@ -495,6 +546,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionURL: purchaseURL, PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled, SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON, CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency, DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance, DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions, DefaultSubscriptions: defaultSubscriptions,
@@ -592,6 +644,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
SoraClientEnabled: updatedSettings.SoraClientEnabled, SoraClientEnabled: updatedSettings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions, DefaultSubscriptions: updatedDefaultSubscriptions,

View File

@@ -276,11 +276,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if limit := a.GetQuotaDailyLimit(); limit > 0 { if limit := a.GetQuotaDailyLimit(); limit > 0 {
out.QuotaDailyLimit = &limit out.QuotaDailyLimit = &limit
used := a.GetQuotaDailyUsed() used := a.GetQuotaDailyUsed()
if a.IsDailyQuotaPeriodExpired() {
used = 0
}
out.QuotaDailyUsed = &used out.QuotaDailyUsed = &used
} }
if limit := a.GetQuotaWeeklyLimit(); limit > 0 { if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
out.QuotaWeeklyLimit = &limit out.QuotaWeeklyLimit = &limit
used := a.GetQuotaWeeklyUsed() used := a.GetQuotaWeeklyUsed()
if a.IsWeeklyQuotaPeriodExpired() {
used = 0
}
out.QuotaWeeklyUsed = &used out.QuotaWeeklyUsed = &used
} }
// 固定时间重置配置 // 固定时间重置配置
@@ -516,14 +522,17 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
// 普通用户 DTO严禁包含管理员字段例如 account_rate_multiplier、ip_address、account // 普通用户 DTO严禁包含管理员字段例如 account_rate_multiplier、ip_address、account
requestType := l.EffectiveRequestType() requestType := l.EffectiveRequestType()
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode) stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
requestedModel := l.RequestedModel
if requestedModel == "" {
requestedModel = l.Model
}
return UsageLog{ return UsageLog{
ID: l.ID, ID: l.ID,
UserID: l.UserID, UserID: l.UserID,
APIKeyID: l.APIKeyID, APIKeyID: l.APIKeyID,
AccountID: l.AccountID, AccountID: l.AccountID,
RequestID: l.RequestID, RequestID: l.RequestID,
Model: l.Model, Model: requestedModel,
UpstreamModel: l.UpstreamModel,
ServiceTier: l.ServiceTier, ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort, ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint, InboundEndpoint: l.InboundEndpoint,
@@ -580,6 +589,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
} }
return &AdminUsageLog{ return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l), UsageLog: usageLogFromServiceUser(l),
UpstreamModel: l.UpstreamModel,
AccountRateMultiplier: l.AccountRateMultiplier, AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress, IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account), Account: AccountSummaryFromService(l.Account),

View File

@@ -1,6 +1,7 @@
package dto package dto
import ( import (
"encoding/json"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -106,6 +107,47 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12) require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
} }
func TestUsageLogFromService_UsesRequestedModelAndKeepsUpstreamAdminOnly(t *testing.T) {
t.Parallel()
upstreamModel := "claude-sonnet-4-20250514"
log := &service.UsageLog{
RequestID: "req_4",
Model: upstreamModel,
RequestedModel: "claude-sonnet-4",
UpstreamModel: &upstreamModel,
}
userDTO := UsageLogFromService(log)
adminDTO := UsageLogFromServiceAdmin(log)
require.Equal(t, "claude-sonnet-4", userDTO.Model)
require.Equal(t, "claude-sonnet-4", adminDTO.Model)
userJSON, err := json.Marshal(userDTO)
require.NoError(t, err)
require.NotContains(t, string(userJSON), "upstream_model")
adminJSON, err := json.Marshal(adminDTO)
require.NoError(t, err)
require.Contains(t, string(adminJSON), `"upstream_model":"claude-sonnet-4-20250514"`)
}
func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *testing.T) {
t.Parallel()
log := &service.UsageLog{
RequestID: "req_legacy",
Model: "claude-3",
}
userDTO := UsageLogFromService(log)
adminDTO := UsageLogFromServiceAdmin(log)
require.Equal(t, "claude-3", userDTO.Model)
require.Equal(t, "claude-3", adminDTO.Model)
}
func f64Ptr(value float64) *float64 { func f64Ptr(value float64) *float64 {
return &value return &value
} }

View File

@@ -15,6 +15,13 @@ type CustomMenuItem struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
// CustomEndpoint represents an admin-configured API endpoint for quick copy.
type CustomEndpoint struct {
Name string `json:"name"`
Endpoint string `json:"endpoint"`
Description string `json:"description"`
}
// SystemSettings represents the admin settings API response payload. // SystemSettings represents the admin settings API response payload.
type SystemSettings struct { type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
@@ -56,6 +63,7 @@ type SystemSettings struct {
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
@@ -114,6 +122,7 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
@@ -218,3 +227,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem {
} }
return filtered return filtered
} }
// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint.
// Returns empty slice on empty/invalid input.
func ParseCustomEndpoints(raw string) []CustomEndpoint {
raw = strings.TrimSpace(raw)
if raw == "" || raw == "[]" {
return []CustomEndpoint{}
}
var items []CustomEndpoint
if err := json.Unmarshal([]byte(raw), &items); err != nil {
return []CustomEndpoint{}
}
return items
}

View File

@@ -334,9 +334,6 @@ type UsageLog struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"` ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.
@@ -396,6 +393,10 @@ type UsageLog struct {
type AdminUsageLog struct { type AdminUsageLog struct {
UsageLog UsageLog
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"` AccountRateMultiplier *float64 `json:"account_rate_multiplier"`

View File

@@ -181,7 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model") defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()

View File

@@ -37,6 +37,16 @@ type OpenAIGatewayHandler struct {
cfg *config.Config cfg *config.Config
} }
func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string {
if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" {
return fallbackModel
}
if apiKey == nil || apiKey.Group == nil {
return ""
}
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler( func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService, gatewayService *service.OpenAIGatewayService,
@@ -657,9 +667,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时, // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型 // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1
defaultMappedModel := c.GetString("openai_messages_fallback_model") defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()

View File

@@ -352,6 +352,30 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
}) })
} }
func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
t.Run("prefers_explicit_fallback_model", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
}
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
})
t.Run("uses_group_default_on_normal_path", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
}
require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, ""))
})
t.Run("returns_empty_without_group_default", func(t *testing.T) {
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, ""))
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, ""))
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{
Group: &service.Group{},
}, ""))
})
}
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,

View File

@@ -632,8 +632,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
// thinking.type is ignored for effort; default xhigh applies. // thinking.type is ignored for effort; default high applies.
assert.Equal(t, "xhigh", resp.Reasoning.Effort) assert.Equal(t, "high", resp.Reasoning.Effort)
assert.Equal(t, "auto", resp.Reasoning.Summary) assert.Equal(t, "auto", resp.Reasoning.Summary)
assert.Contains(t, resp.Include, "reasoning.encrypted_content") assert.Contains(t, resp.Include, "reasoning.encrypted_content")
assert.NotContains(t, resp.Include, "reasoning.summary") assert.NotContains(t, resp.Include, "reasoning.summary")
@@ -650,8 +650,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
// thinking.type is ignored for effort; default xhigh applies. // thinking.type is ignored for effort; default high applies.
assert.Equal(t, "xhigh", resp.Reasoning.Effort) assert.Equal(t, "high", resp.Reasoning.Effort)
assert.Equal(t, "auto", resp.Reasoning.Summary) assert.Equal(t, "auto", resp.Reasoning.Summary)
assert.NotContains(t, resp.Include, "reasoning.summary") assert.NotContains(t, resp.Include, "reasoning.summary")
} }
@@ -666,9 +666,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
// Default effort applies (high → xhigh) even when thinking is disabled. // Default effort applies (high → high) even when thinking is disabled.
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "xhigh", resp.Reasoning.Effort) assert.Equal(t, "high", resp.Reasoning.Effort)
} }
func TestAnthropicToResponses_NoThinking(t *testing.T) { func TestAnthropicToResponses_NoThinking(t *testing.T) {
@@ -680,9 +680,9 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
// Default effort applies (high → xhigh) when no thinking/output_config is set. // Default effort applies (high → high) when no thinking/output_config is set.
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "xhigh", resp.Reasoning.Effort) assert.Equal(t, "high", resp.Reasoning.Effort)
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -690,7 +690,7 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) { func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
// Default is xhigh, but output_config.effort="low" overrides. low→low after mapping. // Default is high, but output_config.effort="low" overrides. low→low after mapping.
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@@ -708,7 +708,7 @@ func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) { func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
// No thinking field, but output_config.effort="medium" → creates reasoning. // No thinking field, but output_config.effort="medium" → creates reasoning.
// medium→high after mapping. // medium→medium after 1:1 mapping.
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@@ -719,12 +719,12 @@ func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "high", resp.Reasoning.Effort) assert.Equal(t, "medium", resp.Reasoning.Effort)
assert.Equal(t, "auto", resp.Reasoning.Summary) assert.Equal(t, "auto", resp.Reasoning.Summary)
} }
func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) { func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
// output_config.effort="high" → mapped to "xhigh". // output_config.effort="high" → mapped to "high" (1:1, both sides' default).
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@@ -732,6 +732,22 @@ func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
OutputConfig: &AnthropicOutputConfig{Effort: "high"}, OutputConfig: &AnthropicOutputConfig{Effort: "high"},
} }
resp, err := AnthropicToResponses(req)
require.NoError(t, err)
require.NotNil(t, resp.Reasoning)
assert.Equal(t, "high", resp.Reasoning.Effort)
assert.Equal(t, "auto", resp.Reasoning.Summary)
}
func TestAnthropicToResponses_OutputConfigMax(t *testing.T) {
// output_config.effort="max" → mapped to OpenAI's highest supported level "xhigh".
req := &AnthropicRequest{
Model: "gpt-5.2",
MaxTokens: 1024,
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
OutputConfig: &AnthropicOutputConfig{Effort: "max"},
}
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
@@ -740,7 +756,7 @@ func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
} }
func TestAnthropicToResponses_NoOutputConfig(t *testing.T) { func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
// No output_config → default xhigh regardless of thinking.type. // No output_config → default high regardless of thinking.type.
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@@ -751,11 +767,11 @@ func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "xhigh", resp.Reasoning.Effort) assert.Equal(t, "high", resp.Reasoning.Effort)
} }
func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) { func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
// output_config present but effort empty (e.g. only format set) → default xhigh. // output_config present but effort empty (e.g. only format set) → default high.
req := &AnthropicRequest{ req := &AnthropicRequest{
Model: "gpt-5.2", Model: "gpt-5.2",
MaxTokens: 1024, MaxTokens: 1024,
@@ -766,7 +782,7 @@ func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
resp, err := AnthropicToResponses(req) resp, err := AnthropicToResponses(req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resp.Reasoning) require.NotNil(t, resp.Reasoning)
assert.Equal(t, "xhigh", resp.Reasoning.Effort) assert.Equal(t, "high", resp.Reasoning.Effort)
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------

View File

@@ -46,9 +46,10 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
} }
// Determine reasoning effort: only output_config.effort controls the // Determine reasoning effort: only output_config.effort controls the
// level; thinking.type is ignored. Default is xhigh when unset. // level; thinking.type is ignored. Default is high when unset (both
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh. // Anthropic and OpenAI default to high).
effort := "high" // default → maps to xhigh // Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh.
effort := "high" // default → both sides' default
if req.OutputConfig != nil && req.OutputConfig.Effort != "" { if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
effort = req.OutputConfig.Effort effort = req.OutputConfig.Effort
} }
@@ -380,18 +381,19 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to // mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
// OpenAI Responses API effort levels. // OpenAI Responses API effort levels.
// //
// Both APIs default to "high". The mapping is 1:1 for shared levels;
// only Anthropic's "max" (Opus 4.6 exclusive) maps to OpenAI's "xhigh"
// (GPT-5.2+ exclusive) as both represent the highest reasoning tier.
//
// low → low // low → low
// medium → high // medium → medium
// high → xhigh // high → high
// max → xhigh
func mapAnthropicEffortToResponses(effort string) string { func mapAnthropicEffortToResponses(effort string) string {
switch effort { if effort == "max" {
case "medium":
return "high"
case "high":
return "xhigh" return "xhigh"
default:
return effort // "low" and any unknown values pass through unchanged
} }
return effort // low→low, medium→medium, high→high, unknown→passthrough
} }
// convertAnthropicToolsToResponses maps Anthropic tool definitions to // convertAnthropicToolsToResponses maps Anthropic tool definitions to

View File

@@ -181,6 +181,35 @@ func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL) assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
} }
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "system", Content: json.RawMessage(`[{"type":"text","text":"You are a careful visual assistant."}]`)},
{Role: "user", Content: json.RawMessage(`[{"type":"text","text":"Describe this image"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
var systemParts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[0].Content, &systemParts))
require.Len(t, systemParts, 1)
assert.Equal(t, "input_text", systemParts[0].Type)
assert.Equal(t, "You are a careful visual assistant.", systemParts[0].Text)
var userParts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[1].Content, &userParts))
require.Len(t, userParts, 2)
assert.Equal(t, "input_image", userParts[1].Type)
assert.Equal(t, "data:image/png;base64,abc123", userParts[1].ImageURL)
}
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
req := &ChatCompletionsRequest{ req := &ChatCompletionsRequest{
Model: "gpt-4o", Model: "gpt-4o",
@@ -398,6 +427,45 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent) assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
} }
func TestChatCompletionsToResponses_ToolArrayContent(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Use the tool"`)},
{
Role: "assistant",
ToolCalls: []ChatToolCall{
{
ID: "call_1",
Type: "function",
Function: ChatFunctionCall{
Name: "inspect_image",
Arguments: `{}`,
},
},
},
},
{
Role: "tool",
ToolCallID: "call_1",
Content: json.RawMessage(
`[{"type":"text","text":"image width: 100"},{"type":"image_url","image_url":{"url":"data:image/png;base64,ignored"}},{"type":"text","text":"; image height: 200"}]`,
),
},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 3)
assert.Equal(t, "function_call_output", items[2].Type)
assert.Equal(t, "call_1", items[2].CallID)
assert.Equal(t, "image width: 100; image height: 200", items[2].Output)
}
func TestResponsesToChatCompletions_Incomplete(t *testing.T) { func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
resp := &ResponsesResponse{ resp := &ResponsesResponse{
ID: "resp_inc", ID: "resp_inc",

View File

@@ -6,6 +6,11 @@ import (
"strings" "strings"
) )
type chatMessageContent struct {
Text *string
Parts []ChatContentPart
}
// ChatCompletionsToResponses converts a Chat Completions request into a // ChatCompletionsToResponses converts a Chat Completions request into a
// Responses API request. The upstream always streams, so Stream is forced to // Responses API request. The upstream always streams, so Stream is forced to
// true. store is always false and reasoning.encrypted_content is always // true. store is always false and reasoning.encrypted_content is always
@@ -113,11 +118,11 @@ func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
// chatSystemToResponses converts a system message. // chatSystemToResponses converts a system message.
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
text, err := parseChatContent(m.Content) parsed, err := parseChatMessageContent(m.Content)
if err != nil { if err != nil {
return nil, err return nil, err
} }
content, err := json.Marshal(text) content, err := marshalChatInputContent(parsed)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -127,39 +132,11 @@ func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// chatUserToResponses converts a user message, handling both plain strings and // chatUserToResponses converts a user message, handling both plain strings and
// multi-modal content arrays. // multi-modal content arrays.
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) { func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Try plain string first. parsed, err := parseChatMessageContent(m.Content)
var s string if err != nil {
if err := json.Unmarshal(m.Content, &s); err == nil {
content, _ := json.Marshal(s)
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
var parts []ChatContentPart
if err := json.Unmarshal(m.Content, &parts); err != nil {
return nil, fmt.Errorf("parse user content: %w", err) return nil, fmt.Errorf("parse user content: %w", err)
} }
content, err := marshalChatInputContent(parsed)
var responseParts []ResponsesContentPart
for _, p := range parts {
switch p.Type {
case "text":
if p.Text != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_text",
Text: p.Text,
})
}
case "image_url":
if p.ImageURL != nil && p.ImageURL.URL != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
})
}
}
}
content, err := json.Marshal(responseParts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -312,16 +289,79 @@ func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
} }
// parseChatContent returns the string value of a ChatMessage Content field. // parseChatContent returns the string value of a ChatMessage Content field.
// Content must be a JSON string. Returns "" if content is null or empty. // Content can be a JSON string or an array of typed parts. Array content is
// flattened to text by concatenating text parts and ignoring non-text parts.
func parseChatContent(raw json.RawMessage) (string, error) { func parseChatContent(raw json.RawMessage) (string, error) {
parsed, err := parseChatMessageContent(raw)
if err != nil {
return "", err
}
if parsed.Text != nil {
return *parsed.Text, nil
}
return flattenChatContentParts(parsed.Parts), nil
}
func parseChatMessageContent(raw json.RawMessage) (chatMessageContent, error) {
if len(raw) == 0 { if len(raw) == 0 {
return "", nil return chatMessageContent{Text: stringPtr("")}, nil
} }
var s string var s string
if err := json.Unmarshal(raw, &s); err != nil { if err := json.Unmarshal(raw, &s); err == nil {
return "", fmt.Errorf("parse content as string: %w", err) return chatMessageContent{Text: &s}, nil
} }
return s, nil
var parts []ChatContentPart
if err := json.Unmarshal(raw, &parts); err == nil {
return chatMessageContent{Parts: parts}, nil
}
return chatMessageContent{}, fmt.Errorf("parse content as string or parts array")
}
func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error) {
if content.Text != nil {
return json.Marshal(*content.Text)
}
return json.Marshal(convertChatContentPartsToResponses(content.Parts))
}
func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart {
var responseParts []ResponsesContentPart
for _, p := range parts {
switch p.Type {
case "text":
if p.Text != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_text",
Text: p.Text,
})
}
case "image_url":
if p.ImageURL != nil && p.ImageURL.URL != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
})
}
}
}
return responseParts
}
func flattenChatContentParts(parts []ChatContentPart) string {
var textParts []string
for _, p := range parts {
if p.Type == "text" && p.Text != "" {
textParts = append(textParts, p.Text)
}
}
return strings.Join(textParts, "")
}
func stringPtr(s string) *string {
return &s
} }
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy // convertChatToolsToResponses maps Chat Completions tool definitions and legacy

View File

@@ -28,50 +28,64 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
// 2. every INSERT/CTE VALUES column list in this file
// 3. execUsageLogInsertNoResult placeholder positions
// 4. scanUsageLog selected column order (via usageLogSelectColumns)
//
// When adding a usage_logs column, update all of those call sites together.
var usageLogInsertArgTypes = [...]string{ var usageLogInsertArgTypes = [...]string{
"bigint", "bigint", // user_id
"bigint", "bigint", // api_key_id
"bigint", "bigint", // account_id
"text", "text", // request_id
"text", "text", // model
"text", "text", // requested_model
"bigint", "text", // upstream_model
"bigint", "bigint", // group_id
"integer", "bigint", // subscription_id
"integer", "integer", // input_tokens
"integer", "integer", // output_tokens
"integer", "integer", // cache_creation_tokens
"integer", "integer", // cache_read_tokens
"integer", "integer", // cache_creation_5m_tokens
"numeric", "integer", // cache_creation_1h_tokens
"numeric", "numeric", // input_cost
"numeric", "numeric", // output_cost
"numeric", "numeric", // cache_creation_cost
"numeric", "numeric", // cache_read_cost
"numeric", "numeric", // total_cost
"numeric", "numeric", // actual_cost
"numeric", "numeric", // rate_multiplier
"smallint", "numeric", // account_rate_multiplier
"smallint", "smallint", // billing_type
"boolean", "smallint", // request_type
"boolean", "boolean", // stream
"integer", "boolean", // openai_ws_mode
"integer", "integer", // duration_ms
"text", "integer", // first_token_ms
"text", "text", // user_agent
"integer", "text", // ip_address
"text", "integer", // image_count
"text", "text", // image_size
"text", "text", // media_type
"text", "text", // service_tier
"text", "text", // reasoning_effort
"text", "text", // inbound_endpoint
"boolean", "text", // upstream_endpoint
"timestamptz", "boolean", // cache_ttl_overridden
"timestamptz", // created_at
} }
const rawUsageLogModelColumn = "model"
// rawUsageLogModelColumn preserves the exact stored usage_logs.model semantics for direct filters.
// Historical rows may contain upstream/billing model values, while newer rows store requested_model.
// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead.
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{ var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00", "hour": "YYYY-MM-DD HH24:00",
@@ -88,6 +102,30 @@ func safeDateFormat(granularity string) string {
return "YYYY-MM-DD" return "YYYY-MM-DD"
} }
// appendRawUsageLogModelWhereCondition keeps direct model filters on the raw model column for backward
// compatibility with historical rows. Requested/upstream analytics must use
// resolveModelDimensionExpression instead.
func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model string) ([]string, []any) {
if strings.TrimSpace(model) == "" {
return conditions, args
}
conditions = append(conditions, fmt.Sprintf("%s = $%d", rawUsageLogModelColumn, len(args)+1))
args = append(args, model)
return conditions, args
}
// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward
// compatibility with historical rows. Requested/upstream analytics must use
// resolveModelDimensionExpression instead.
func appendRawUsageLogModelQueryFilter(query string, args []any, model string) (string, []any) {
if strings.TrimSpace(model) == "" {
return query, args
}
query += fmt.Sprintf(" AND %s = $%d", rawUsageLogModelColumn, len(args)+1)
args = append(args, model)
return query, args
}
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
sql sqlExecutor sql sqlExecutor
@@ -278,6 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
@@ -313,12 +352,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $1, $2, $3, $4, $5, $6, $7,
$7, $8, $8, $9,
$9, $10, $11, $12, $10, $11, $12, $13,
$13, $14, $14, $15,
$15, $16, $17, $18, $19, $20, $16, $17, $18, $19, $20, $21,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
@@ -709,6 +748,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
@@ -779,6 +819,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
@@ -820,6 +861,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
@@ -901,6 +943,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
@@ -937,7 +980,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*39) args := make([]any, 0, len(preparedList)*40)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
@@ -968,6 +1011,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
@@ -1009,6 +1053,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
@@ -1058,6 +1103,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id, account_id,
request_id, request_id,
model, model,
requested_model,
upstream_model, upstream_model,
group_id, group_id,
subscription_id, subscription_id,
@@ -1093,12 +1139,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $1, $2, $3, $4, $5, $6, $7,
$7, $8, $8, $9,
$9, $10, $11, $12, $10, $11, $12, $13,
$13, $14, $14, $15,
$15, $16, $17, $18, $19, $20, $16, $17, $18, $19, $20, $21,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
@@ -1130,6 +1176,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint) inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint)
requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model)
}
upstreamModel := nullString(log.UpstreamModel) upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any var requestIDArg any
@@ -1148,6 +1198,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID, log.AccountID,
requestIDArg, requestIDArg,
log.Model, log.Model,
nullString(&requestedModel),
upstreamModel, upstreamModel,
groupID, groupID,
subscriptionID, subscriptionID,
@@ -1702,7 +1753,7 @@ func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, acco
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 // GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
// 性能优化:数据库层聚合计算,避免应用层循环统计 // 性能优化:数据库层聚合计算,避免应用层循环统计
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := ` query := fmt.Sprintf(`
SELECT SELECT
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -1712,8 +1763,8 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN
COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs FROM usage_logs
WHERE model = $1 AND created_at >= $2 AND created_at < $3 WHERE %s = $1 AND created_at >= $2 AND created_at < $3
` `, rawUsageLogModelColumn)
var stats usagestats.UsageStats var stats usagestats.UsageStats
if err := scanSingleRow( if err := scanSingleRow(
@@ -1837,7 +1888,7 @@ func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
} }
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn)
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
return logs, nil, err return logs, nil, err
} }
@@ -2532,10 +2583,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID) args = append(args, filters.GroupID)
} }
if filters.Model != "" { conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model)
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil { if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
@@ -2768,10 +2816,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID) args = append(args, groupID)
} }
if model != "" { query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil { if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
@@ -3126,13 +3171,14 @@ func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayS
// resolveModelDimensionExpression maps model source type to a safe SQL expression. // resolveModelDimensionExpression maps model source type to a safe SQL expression.
func resolveModelDimensionExpression(modelType string) string { func resolveModelDimensionExpression(modelType string) string {
requestedExpr := "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
switch usagestats.NormalizeModelSource(modelType) { switch usagestats.NormalizeModelSource(modelType) {
case usagestats.ModelSourceUpstream: case usagestats.ModelSourceUpstream:
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)" return fmt.Sprintf("COALESCE(NULLIF(TRIM(upstream_model), ''), %s)", requestedExpr)
case usagestats.ModelSourceMapping: case usagestats.ModelSourceMapping:
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))" return fmt.Sprintf("(%s || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), %s))", requestedExpr, requestedExpr)
default: default:
return "model" return requestedExpr
} }
} }
@@ -3204,10 +3250,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID) args = append(args, filters.GroupID)
} }
if filters.Model != "" { conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model)
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil { if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
@@ -3336,10 +3379,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID) args = append(args, groupID)
} }
if model != "" { query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil { if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
@@ -3410,10 +3450,7 @@ func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID) args = append(args, groupID)
} }
if model != "" { query, args = appendRawUsageLogModelQueryFilter(query, args, model)
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil { if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
@@ -3888,6 +3925,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64 accountID int64
requestID sql.NullString requestID sql.NullString
model string model string
requestedModel sql.NullString
upstreamModel sql.NullString upstreamModel sql.NullString
groupID sql.NullInt64 groupID sql.NullInt64
subscriptionID sql.NullInt64 subscriptionID sql.NullInt64
@@ -3931,6 +3969,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID, &accountID,
&requestID, &requestID,
&model, &model,
&requestedModel,
&upstreamModel, &upstreamModel,
&groupID, &groupID,
&subscriptionID, &subscriptionID,
@@ -3975,6 +4014,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
APIKeyID: apiKeyID, APIKeyID: apiKeyID,
AccountID: accountID, AccountID: accountID,
Model: model, Model: model,
RequestedModel: coalesceTrimmedString(requestedModel, model),
InputTokens: inputTokens, InputTokens: inputTokens,
OutputTokens: outputTokens, OutputTokens: outputTokens,
CacheCreationTokens: cacheCreationTokens, CacheCreationTokens: cacheCreationTokens,
@@ -4181,6 +4221,13 @@ func nullString(v *string) sql.NullString {
return sql.NullString{String: *v, Valid: true} return sql.NullString{String: *v, Valid: true}
} }
func coalesceTrimmedString(v sql.NullString, fallback string) string {
if v.Valid && strings.TrimSpace(v.String) != "" {
return v.String
}
return fallback
}
func setToSlice(set map[int64]struct{}) []int64 { func setToSlice(set map[int64]struct{}) []int64 {
out := make([]int64, 0, len(set)) out := make([]int64, 0, len(set))
for id := range set { for id := range set {

View File

@@ -34,11 +34,11 @@ func TestResolveModelDimensionExpression(t *testing.T) {
modelType string modelType string
want string want string
}{ }{
{usagestats.ModelSourceRequested, "model"}, {usagestats.ModelSourceRequested, "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"}, {usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model))"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"}, {usagestats.ModelSourceMapping, "(COALESCE(NULLIF(TRIM(requested_model), ''), model) || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model)))"},
{"", "model"}, {"", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
{"invalid", "model"}, {"invalid", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
} }
for _, tc := range tests { for _, tc := range tests {

View File

@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
@@ -21,20 +22,21 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
log := &service.UsageLog{ log := &service.UsageLog{
UserID: 1, UserID: 1,
APIKeyID: 2, APIKeyID: 2,
AccountID: 3, AccountID: 3,
RequestID: "req-1", RequestID: "req-1",
Model: "gpt-5", Model: "gpt-5",
InputTokens: 10, RequestedModel: "gpt-5",
OutputTokens: 20, InputTokens: 10,
TotalCost: 1, OutputTokens: 20,
ActualCost: 1, TotalCost: 1,
BillingType: service.BillingTypeBalance, ActualCost: 1,
RequestType: service.RequestTypeWSV2, BillingType: service.BillingTypeBalance,
Stream: false, RequestType: service.RequestTypeWSV2,
OpenAIWSMode: false, Stream: false,
CreatedAt: createdAt, OpenAIWSMode: false,
CreatedAt: createdAt,
} }
mock.ExpectQuery("INSERT INTO usage_logs"). mock.ExpectQuery("INSERT INTO usage_logs").
@@ -44,6 +46,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID, log.AccountID,
log.RequestID, log.RequestID,
log.Model, log.Model,
log.RequestedModel,
sqlmock.AnyArg(), // upstream_model sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id sqlmock.AnyArg(), // subscription_id
@@ -99,13 +102,14 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC) createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
serviceTier := "priority" serviceTier := "priority"
log := &service.UsageLog{ log := &service.UsageLog{
UserID: 1, UserID: 1,
APIKeyID: 2, APIKeyID: 2,
AccountID: 3, AccountID: 3,
RequestID: "req-service-tier", RequestID: "req-service-tier",
Model: "gpt-5.4", Model: "gpt-5.4",
ServiceTier: &serviceTier, RequestedModel: "gpt-5.4",
CreatedAt: createdAt, ServiceTier: &serviceTier,
CreatedAt: createdAt,
} }
mock.ExpectQuery("INSERT INTO usage_logs"). mock.ExpectQuery("INSERT INTO usage_logs").
@@ -115,6 +119,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.AccountID, log.AccountID,
log.RequestID, log.RequestID,
log.Model, log.Model,
log.RequestedModel,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
@@ -158,6 +163,75 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
require.NoError(t, mock.ExpectationsWereMet()) require.NoError(t, mock.ExpectationsWereMet())
} }
func TestBuildUsageLogBestEffortInsertQuery_IncludesRequestedModelColumn(t *testing.T) {
prepared := prepareUsageLogInsert(&service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-best-effort-query",
Model: "gpt-5",
RequestedModel: "gpt-5",
CreatedAt: time.Date(2025, 1, 3, 12, 0, 0, 0, time.UTC),
})
query, args := buildUsageLogBestEffortInsertQuery([]usageLogInsertPrepared{prepared})
require.Contains(t, query, "INSERT INTO usage_logs (")
require.Contains(t, query, "\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,")
require.Contains(t, query, "\n\t\t\trequest_id,\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,")
require.Len(t, args, len(prepared.args))
require.Equal(t, prepared.args[5], args[5])
}
func TestExecUsageLogInsertNoResult_PersistsRequestedModel(t *testing.T) {
db, mock := newSQLMock(t)
prepared := prepareUsageLogInsert(&service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-best-effort-exec",
Model: "gpt-5",
RequestedModel: "gpt-5",
CreatedAt: time.Date(2025, 1, 4, 12, 0, 0, 0, time.UTC),
})
mock.ExpectExec("INSERT INTO usage_logs").
WithArgs(anySliceToDriverValues(prepared.args)...).
WillReturnResult(sqlmock.NewResult(0, 1))
err := execUsageLogInsertNoResult(context.Background(), db, prepared)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) {
prepared := prepareUsageLogInsert(&service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-arg-count",
Model: "gpt-5",
RequestedModel: "gpt-5",
CreatedAt: time.Date(2025, 1, 5, 12, 0, 0, 0, time.UTC),
})
require.Len(t, prepared.args, len(usageLogInsertArgTypes))
}
func TestCoalesceTrimmedString(t *testing.T) {
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback"))
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback"))
require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback"))
}
func anySliceToDriverValues(values []any) []driver.Value {
out := make([]driver.Value, 0, len(values))
for _, value := range values {
out = append(out, value)
}
return out
}
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t) db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db} repo := &usageLogRepository{sql: db}
@@ -354,7 +428,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(20), // api_key_id int64(20), // api_key_id
int64(30), // account_id int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"}, sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model "gpt-5", // model
sql.NullString{Valid: true, String: "gpt-5"}, // requested_model
sql.NullString{}, // upstream_model sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id sql.NullInt64{}, // subscription_id
@@ -407,6 +482,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31), int64(31),
sql.NullString{Valid: true, String: "req-2"}, sql.NullString{Valid: true, String: "req-2"},
"gpt-5", "gpt-5",
sql.NullString{Valid: true, String: "gpt-5"},
sql.NullString{}, sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
@@ -449,6 +525,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32), int64(32),
sql.NullString{Valid: true, String: "req-3"}, sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4", "gpt-5.4",
sql.NullString{Valid: true, String: "gpt-5.4"},
sql.NullString{}, sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},

View File

@@ -540,7 +540,8 @@ func TestAPIContracts(t *testing.T) {
"max_claude_code_version": "", "max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false, "allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false, "backend_mode_enabled": false,
"custom_menu_items": [] "custom_menu_items": [],
"custom_endpoints": []
} }
}`, }`,
}, },

View File

@@ -1543,6 +1543,24 @@ func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
return time.Since(periodStart) >= dur return time.Since(periodStart) >= dur
} }
// IsDailyQuotaPeriodExpired 检查日配额周期是否已过期(用于显示层判断是否需要将 used 归零)
func (a *Account) IsDailyQuotaPeriodExpired() bool {
start := a.getExtraTime("quota_daily_start")
if a.GetQuotaDailyResetMode() == "fixed" {
return a.isFixedDailyPeriodExpired(start)
}
return isPeriodExpired(start, 24*time.Hour)
}
// IsWeeklyQuotaPeriodExpired 检查周配额周期是否已过期(用于显示层判断是否需要将 used 归零)
func (a *Account) IsWeeklyQuotaPeriodExpired() bool {
start := a.getExtraTime("quota_weekly_start")
if a.GetQuotaWeeklyResetMode() == "fixed" {
return a.isFixedWeeklyPeriodExpired(start)
}
return isPeriodExpired(start, 7*24*time.Hour)
}
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true // IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true
func (a *Account) IsQuotaExceeded() bool { func (a *Account) IsQuotaExceeded() bool {
// 总额度 // 总额度

View File

@@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: billingModel, // 使用映射模型用于计费和日志 Model: originalModel,
UpstreamModel: billingModel,
Stream: claudeReq.Stream, Stream: claudeReq.Stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
@@ -2435,7 +2436,8 @@ handleSuccess:
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: billingModel, Model: originalModel,
UpstreamModel: billingModel,
Stream: stream, Stream: stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,

View File

@@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
result, err := svc.Forward(context.Background(), c, account, body, false) result, err := svc.Forward(context.Background(), c, account, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, "claude-sonnet-4-5", result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
} }
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel // TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
@@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, "gemini-2.5-flash", result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
} }
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
@@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false) result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, originalModel, result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry") require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
firstReq := string(upstream.requestBodies[0]) firstReq := string(upstream.requestBodies[0])

View File

@@ -222,10 +222,10 @@ func (s *BillingService) initFallbackPricing() {
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
} }
s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{ s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{
InputPricePerToken: 7.5e-7, InputPricePerToken: 7.5e-7,
OutputPricePerToken: 4.5e-6, OutputPricePerToken: 4.5e-6,
CacheReadPricePerToken: 7.5e-8, CacheReadPricePerToken: 7.5e-8,
SupportsCacheBreakdown: false, SupportsCacheBreakdown: false,
} }
s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
InputPricePerToken: 2e-7, InputPricePerToken: 2e-7,

View File

@@ -119,6 +119,7 @@ const (
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL作为 iframe src SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL作为 iframe src
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项JSON 数组) SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项JSON 数组)
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表JSON 数组)
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量

View File

@@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
} }
func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
mappedModel := "claude-sonnet-4-20250514"
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_models_split",
Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6},
Model: "claude-sonnet-4",
UpstreamModel: mappedModel,
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}

View File

@@ -482,10 +482,12 @@ type ClaudeUsage struct {
// ForwardResult 转发结果 // ForwardResult 转发结果
type ForwardResult struct { type ForwardResult struct {
RequestID string RequestID string
Usage ClaudeUsage Usage ClaudeUsage
Model string Model string
UpstreamModel string // Actual upstream model after mapping (empty = no mapping) // UpstreamModel is the actual upstream model after mapping.
// Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings.
UpstreamModel string
Stream bool Stream bool
Duration time.Duration Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求) FirstTokenMs *int // 首字时间(流式请求)
@@ -4197,7 +4199,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
break break
} }
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) logger.LegacyPrintf("service.gateway", "[warn] Account %d: thinking blocks have invalid signature, retrying with filtered blocks", account.ID)
// Conservative two-stage fallback: // Conservative two-stage fallback:
// 1) Disable thinking + thinking->text (preserve content) // 1) Disable thinking + thinking->text (preserve content)
@@ -4212,7 +4214,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
if retryResp.StatusCode < 400 { if retryResp.StatusCode < 400 {
logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID) logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID)
resp = retryResp resp = retryResp
break break
} }
@@ -6102,13 +6104,9 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return false return false
} }
// Log for debugging
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg)
// 检测signature相关的错误更宽松的匹配 // 检测signature相关的错误更宽松的匹配
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
if strings.Contains(msg, "signature") { if strings.Contains(msg, "signature") {
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error")
return true return true
} }
@@ -7516,6 +7514,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
var cost *CostBreakdown var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" { if result.MediaType == "image" || result.MediaType == "video" {
@@ -7531,7 +7530,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.MediaType == "image" { if result.MediaType == "image" {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else { } else {
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
} }
} else if result.MediaType == "prompt" { } else if result.MediaType == "prompt" {
cost = &CostBreakdown{} cost = &CostBreakdown{}
@@ -7545,7 +7544,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
Price4K: apiKey.Group.ImagePrice4K, Price4K: apiKey.Group.ImagePrice4K,
} }
} }
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else { } else {
// Token 计费 // Token 计费
tokens := UsageTokens{ tokens := UsageTokens{
@@ -7557,7 +7556,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
if err != nil { if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0} cost = &CostBreakdown{ActualCost: 0}
@@ -7589,6 +7588,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
@@ -7719,6 +7719,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} }
var cost *CostBreakdown var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.ImageCount > 0 { if result.ImageCount > 0 {
@@ -7731,7 +7732,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
Price4K: apiKey.Group.ImagePrice4K, Price4K: apiKey.Group.ImagePrice4K,
} }
} }
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else { } else {
// Token 计费(使用长上下文计费方法) // Token 计费(使用长上下文计费方法)
tokens := UsageTokens{ tokens := UsageTokens{
@@ -7743,7 +7744,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
if err != nil { if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0} cost = &CostBreakdown{ActualCost: 0}
@@ -7771,6 +7772,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),

View File

@@ -1028,14 +1028,15 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
} }
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
Stream: req.Stream, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: req.Stream,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
ImageCount: imageCount, FirstTokenMs: firstTokenMs,
ImageSize: imageSize, ImageCount: imageCount,
ImageSize: imageSize,
}, nil }, nil
} }
@@ -1241,12 +1242,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
estimated := estimateGeminiCountTokens(body) estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{ return &ForwardResult{
RequestID: "", RequestID: "",
Usage: ClaudeUsage{}, Usage: ClaudeUsage{},
Model: originalModel, Model: originalModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
FirstTokenMs: nil, Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil }, nil
} }
setOpsUpstreamError(c, 0, safeErr, "") setOpsUpstreamError(c, 0, safeErr, "")
@@ -1310,12 +1312,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
estimated := estimateGeminiCountTokens(body) estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{ return &ForwardResult{
RequestID: "", RequestID: "",
Usage: ClaudeUsage{}, Usage: ClaudeUsage{},
Model: originalModel, Model: originalModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
FirstTokenMs: nil, Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil }, nil
} }
// Final attempt: surface the upstream error body (passed through below) instead of a generic retry error. // Final attempt: surface the upstream error body (passed through below) instead of a generic retry error.
@@ -1350,12 +1353,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
estimated := estimateGeminiCountTokens(body) estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: ClaudeUsage{}, Usage: ClaudeUsage{},
Model: originalModel, Model: originalModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
FirstTokenMs: nil, Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil }, nil
} }
@@ -1527,14 +1531,15 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
} }
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
Stream: stream, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: stream,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
ImageCount: imageCount, FirstTokenMs: firstTokenMs,
ImageSize: imageSize, ImageCount: imageCount,
ImageSize: imageSize,
}, nil }, nil
} }

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -15,6 +16,30 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type geminiCompatHTTPUpstreamStub struct {
response *http.Response
err error
calls int
lastReq *http.Request
}
func (s *geminiCompatHTTPUpstreamStub) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
s.calls++
s.lastReq = req
if s.err != nil {
return nil, s.err
}
if s.response == nil {
return nil, fmt.Errorf("missing stub response")
}
resp := *s.response
return &resp, nil
}
func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct { tests := []struct {
@@ -170,6 +195,42 @@ func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLo
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
} }
func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpstreamModel(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
httpStub := &geminiCompatHTTPUpstreamStub{
response: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"x-request-id": []string{"gemini-req-1"}},
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)),
},
}
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
account := &Account{
ID: 1,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-key",
"model_mapping": map[string]any{
"claude-sonnet-4": "claude-sonnet-4-20250514",
},
},
}
body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "claude-sonnet-4", result.Model)
require.Equal(t, "claude-sonnet-4-20250514", result.UpstreamModel)
require.Equal(t, 1, httpStub.calls)
require.NotNil(t, httpStub.lastReq)
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{ claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001", "model": "claude-haiku-4-5-20251001",

View File

@@ -879,6 +879,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog) require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel) require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel) require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
require.NotNil(t, usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ServiceTier)
@@ -894,6 +895,40 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require.Equal(t, 1, userRepo.deductCalls) require.Equal(t, 1, userRepo.deductCalls)
} }
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_upstream_model_billing_fallback",
Model: "gpt-5.1",
UpstreamModel: "gpt-5.1-codex",
Usage: usage,
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
require.Equal(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost)
require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount)
}
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}

View File

@@ -4110,9 +4110,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
} }
billingModel := result.Model billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" { if result.BillingModel != "" {
billingModel = result.BillingModel billingModel = strings.TrimSpace(result.BillingModel)
} }
serviceTier := "" serviceTier := ""
if result.ServiceTier != nil { if result.ServiceTier != nil {
@@ -4140,6 +4140,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,

View File

@@ -68,3 +68,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
}) })
} }
} }
func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) {
account := &Account{
Credentials: map[string]any{},
}
withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "")
if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1")
}
withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")
if got := normalizeCodexModel(withDefault); got != "gpt-5.4" {
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4")
}
}

View File

@@ -2328,6 +2328,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
RequestID: responseID, RequestID: responseID,
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: extractOpenAIServiceTier(reqBody), ServiceTier: extractOpenAIServiceTier(reqBody),
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
Stream: reqStream, Stream: reqStream,
@@ -2945,6 +2946,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
RequestID: responseID, RequestID: responseID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: extractOpenAIServiceTierFromBody(payload), ServiceTier: extractOpenAIServiceTierFromBody(payload),
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
Stream: reqStream, Stream: reqStream,

View File

@@ -88,6 +88,7 @@ func (s *OpsAlertEvaluatorService) Start() {
if s.stopCh == nil { if s.stopCh == nil {
s.stopCh = make(chan struct{}) s.stopCh = make(chan struct{})
} }
s.wg.Add(1)
go s.run() go s.run()
}) })
} }
@@ -105,7 +106,6 @@ func (s *OpsAlertEvaluatorService) Stop() {
} }
func (s *OpsAlertEvaluatorService) run() { func (s *OpsAlertEvaluatorService) run() {
s.wg.Add(1)
defer s.wg.Done() defer s.wg.Done()
// Start immediately to produce early feedback in ops dashboard. // Start immediately to produce early feedback in ops dashboard.
@@ -848,7 +848,9 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc
return nil, false return nil, false
} }
return func() { return func() {
_, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result() releaseCtx, releaseCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer releaseCancel()
_, _ = opsAlertEvaluatorReleaseScript.Run(releaseCtx, s.redisClient, []string{key}, s.instanceID).Result()
}, true }, true
} }

View File

@@ -150,6 +150,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyPurchaseSubscriptionURL, SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled, SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems, SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled, SettingKeyBackendModeEnabled,
} }
@@ -195,6 +196,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}, nil }, nil
@@ -247,6 +249,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
@@ -272,6 +275,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
Version: s.version, Version: s.version,
@@ -314,6 +318,18 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result return result
} }
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
if raw == "" {
return json.RawMessage("[]")
}
if json.Valid([]byte(raw)) {
return json.RawMessage(raw)
}
return json.RawMessage("[]")
}
// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url // GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url
// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. // and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
@@ -454,6 +470,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
// 默认配置 // 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@@ -740,6 +757,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyPurchaseSubscriptionURL: "", SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false", SettingKeySoraClientEnabled: "false",
SettingKeyCustomMenuItems: "[]", SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]", SettingKeyDefaultSubscriptions: "[]",
@@ -805,6 +823,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
} }

View File

@@ -43,6 +43,7 @@ type SystemSettings struct {
PurchaseSubscriptionURL string PurchaseSubscriptionURL string
SoraClientEnabled bool SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
DefaultConcurrency int DefaultConcurrency int
DefaultBalance float64 DefaultBalance float64
@@ -104,6 +105,7 @@ type PublicSettings struct {
PurchaseSubscriptionURL string PurchaseSubscriptionURL string
SoraClientEnabled bool SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
BackendModeEnabled bool BackendModeEnabled bool

View File

@@ -148,10 +148,13 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
return nil, errors.New("model is required") return nil, errors.New("model is required")
} }
originalModel := reqModel
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
var upstreamModel string
if mappedModel != "" && mappedModel != reqModel { if mappedModel != "" && mappedModel != reqModel {
reqModel = mappedModel reqModel = mappedModel
upstreamModel = mappedModel
} }
modelCfg, ok := GetSoraModelConfig(reqModel) modelCfg, ok := GetSoraModelConfig(reqModel)
@@ -213,13 +216,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
} }
return &ForwardResult{ return &ForwardResult{
RequestID: "", RequestID: "",
Model: reqModel, Model: originalModel,
Stream: clientStream, UpstreamModel: upstreamModel,
Duration: time.Since(startTime), Stream: clientStream,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
Usage: ClaudeUsage{}, FirstTokenMs: firstTokenMs,
MediaType: "prompt", Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil }, nil
} }
@@ -269,13 +273,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
return &ForwardResult{ return &ForwardResult{
RequestID: "", RequestID: "",
Model: reqModel, Model: originalModel,
Stream: clientStream, UpstreamModel: upstreamModel,
Duration: time.Since(startTime), Stream: clientStream,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
Usage: ClaudeUsage{}, FirstTokenMs: firstTokenMs,
MediaType: "prompt", Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil }, nil
} }
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
@@ -419,16 +424,17 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
} }
return &ForwardResult{ return &ForwardResult{
RequestID: taskID, RequestID: taskID,
Model: reqModel, Model: originalModel,
Stream: clientStream, UpstreamModel: upstreamModel,
Duration: time.Since(startTime), Stream: clientStream,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
Usage: ClaudeUsage{}, FirstTokenMs: firstTokenMs,
MediaType: mediaType, Usage: ClaudeUsage{},
MediaURL: firstMediaURL(finalURLs), MediaType: mediaType,
ImageCount: imageCount, MediaURL: firstMediaURL(finalURLs),
ImageSize: imageSize, ImageCount: imageCount,
ImageSize: imageSize,
}, nil }, nil
} }

View File

@@ -144,6 +144,11 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
ID: 1, ID: 1,
Platform: PlatformSora, Platform: PlatformSora,
Status: StatusActive, Status: StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"prompt-enhance-short-10s": "prompt-enhance-short-15s",
},
},
} }
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
@@ -152,6 +157,7 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType) require.Equal(t, "prompt", result.MediaType)
require.Equal(t, "prompt-enhance-short-10s", result.Model) require.Equal(t, "prompt-enhance-short-10s", result.Model)
require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
} }
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {

View File

@@ -98,6 +98,9 @@ type UsageLog struct {
AccountID int64 AccountID int64
RequestID string RequestID string
Model string Model string
// RequestedModel is the client-requested model name recorded for stable user/admin display.
// Empty should be treated as Model for backward compatibility with historical rows.
RequestedModel string
// UpstreamModel is the actual model sent to the upstream provider after mapping. // UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is). // Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string UpstreamModel *string

View File

@@ -19,3 +19,10 @@ func optionalNonEqualStringPtr(value, compare string) *string {
} }
return &value return &value
} }
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
if trimmedUpstream := strings.TrimSpace(upstreamModel); trimmedUpstream != "" {
return trimmedUpstream
}
return strings.TrimSpace(requestedModel)
}

View File

@@ -0,0 +1,3 @@
-- Add requested_model field to usage_logs for normalized request/upstream model tracking.
-- NULL means historical rows written before requested_model dual-write was introduced.
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100);

View File

@@ -0,0 +1,3 @@
-- Support requested_model / upstream_model aggregations with time-range filters.
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_requested_model_upstream_model
ON usage_logs (created_at, requested_model, upstream_model);

View File

@@ -4,7 +4,7 @@
*/ */
import { apiClient } from '../client' import { apiClient } from '../client'
import type { CustomMenuItem } from '@/types' import type { CustomMenuItem, CustomEndpoint } from '@/types'
export interface DefaultSubscriptionSetting { export interface DefaultSubscriptionSetting {
group_id: number group_id: number
@@ -43,6 +43,7 @@ export interface SystemSettings {
sora_client_enabled: boolean sora_client_enabled: boolean
backend_mode_enabled: boolean backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
// SMTP settings // SMTP settings
smtp_host: string smtp_host: string
smtp_port: number smtp_port: number
@@ -112,6 +113,7 @@ export interface UpdateSettingsRequest {
sora_client_enabled?: boolean sora_client_enabled?: boolean
backend_mode_enabled?: boolean backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[] custom_menu_items?: CustomMenuItem[]
custom_endpoints?: CustomEndpoint[]
smtp_host?: string smtp_host?: string
smtp_port?: number smtp_port?: number
smtp_username?: string smtp_username?: string

View File

@@ -39,6 +39,7 @@ const DataTableStub = {
template: ` template: `
<div> <div>
<div v-for="row in data" :key="row.request_id"> <div v-for="row in data" :key="row.request_id">
<slot name="cell-model" :row="row" :value="row.model" />
<slot name="cell-cost" :row="row" /> <slot name="cell-cost" :row="row" />
</div> </div>
</div> </div>
@@ -108,4 +109,42 @@ describe('admin UsageTable tooltip', () => {
expect(text).toContain('$30.0000 / 1M tokens') expect(text).toContain('$30.0000 / 1M tokens')
expect(text).toContain('$0.069568') expect(text).toContain('$0.069568')
}) })
it('shows requested and upstream models separately for admin rows', () => {
const row = {
request_id: 'req-admin-model-1',
model: 'claude-sonnet-4',
upstream_model: 'claude-sonnet-4-20250514',
actual_cost: 0,
total_cost: 0,
account_rate_multiplier: 1,
rate_multiplier: 1,
input_cost: 0,
output_cost: 0,
cache_creation_cost: 0,
cache_read_cost: 0,
input_tokens: 0,
output_tokens: 0,
}
const wrapper = mount(UsageTable, {
props: {
data: [row],
loading: false,
columns: [],
},
global: {
stubs: {
DataTable: DataTableStub,
EmptyState: true,
Icon: true,
Teleport: true,
},
},
})
const text = wrapper.text()
expect(text).toContain('claude-sonnet-4')
expect(text).toContain('claude-sonnet-4-20250514')
})
}) })

View File

@@ -0,0 +1,141 @@
<script setup lang="ts">
import { computed, onBeforeUnmount, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { useClipboard } from '@/composables/useClipboard'
import type { CustomEndpoint } from '@/types'
const props = defineProps<{
apiBaseUrl: string
customEndpoints: CustomEndpoint[]
}>()
const { t } = useI18n()
const { copyToClipboard } = useClipboard()
const copiedEndpoint = ref<string | null>(null)
let copiedResetTimer: number | undefined
const allEndpoints = computed(() => {
const items: Array<{ name: string; endpoint: string; description: string; isDefault: boolean }> = []
if (props.apiBaseUrl) {
items.push({
name: t('keys.endpoints.title'),
endpoint: props.apiBaseUrl,
description: '',
isDefault: true,
})
}
for (const ep of props.customEndpoints) {
items.push({ ...ep, isDefault: false })
}
return items
})
async function copy(url: string) {
const success = await copyToClipboard(url, t('keys.endpoints.copied'))
if (!success) return
copiedEndpoint.value = url
if (copiedResetTimer !== undefined) {
window.clearTimeout(copiedResetTimer)
}
copiedResetTimer = window.setTimeout(() => {
if (copiedEndpoint.value === url) {
copiedEndpoint.value = null
}
}, 1800)
}
function tooltipHint(endpoint: string): string {
return copiedEndpoint.value === endpoint
? t('keys.endpoints.copiedHint')
: t('keys.endpoints.clickToCopy')
}
function speedTestUrl(endpoint: string): string {
return `https://www.tcptest.cn/http/${encodeURIComponent(endpoint)}`
}
onBeforeUnmount(() => {
if (copiedResetTimer !== undefined) {
window.clearTimeout(copiedResetTimer)
}
})
</script>
<template>
<div v-if="allEndpoints.length > 0" class="flex flex-wrap gap-2">
<div
v-for="(item, index) in allEndpoints"
:key="index"
class="flex items-center gap-1.5 rounded-lg border border-gray-200 bg-white px-2.5 py-1.5 text-xs transition-colors hover:border-primary-200 dark:border-dark-600 dark:bg-dark-800 dark:hover:border-primary-700"
>
<span class="font-medium text-gray-600 dark:text-gray-300">{{ item.name }}</span>
<span
v-if="item.isDefault"
class="rounded bg-primary-50 px-1 py-px text-[10px] font-medium leading-tight text-primary-600 dark:bg-primary-900/30 dark:text-primary-400"
>{{ t('keys.endpoints.default') }}</span>
<span class="text-gray-300 dark:text-dark-500">|</span>
<div class="group/endpoint relative flex items-center gap-1.5">
<div
class="pointer-events-none absolute bottom-full left-1/2 z-20 mb-2 w-max max-w-[24rem] -translate-x-1/2 translate-y-1 rounded-xl border border-slate-200 bg-white px-3 py-2.5 text-left opacity-0 shadow-[0_14px_36px_-20px_rgba(15,23,42,0.35)] ring-1 ring-slate-200/80 transition-all duration-150 group-hover/endpoint:translate-y-0 group-hover/endpoint:opacity-100 group-focus-within/endpoint:translate-y-0 group-focus-within/endpoint:opacity-100 dark:border-slate-700 dark:bg-slate-900 dark:ring-slate-700/70"
>
<p
v-if="item.description"
class="max-w-[24rem] break-words text-xs leading-5 text-slate-600 dark:text-slate-200"
>
{{ item.description }}
</p>
<p
class="flex items-center gap-1.5 text-[11px] leading-4 text-primary-600 dark:text-primary-300"
:class="item.description ? 'mt-1.5' : ''"
>
<span class="h-1.5 w-1.5 rounded-full bg-primary-500 dark:bg-primary-300"></span>
{{ tooltipHint(item.endpoint) }}
</p>
<div class="absolute left-1/2 top-full h-3 w-3 -translate-x-1/2 -translate-y-1/2 rotate-45 border-b border-r border-slate-200 bg-white dark:border-slate-700 dark:bg-slate-900"></div>
</div>
<code
class="cursor-pointer font-mono text-gray-500 decoration-gray-400 decoration-dashed underline-offset-2 hover:text-primary-600 hover:underline focus:text-primary-600 focus:underline focus:outline-none dark:text-gray-400 dark:decoration-gray-500 dark:hover:text-primary-400 dark:focus:text-primary-400"
role="button"
tabindex="0"
@click="copy(item.endpoint)"
@keydown.enter.prevent="copy(item.endpoint)"
@keydown.space.prevent="copy(item.endpoint)"
>{{ item.endpoint }}</code>
<button
type="button"
class="rounded p-0.5 transition-colors"
:class="copiedEndpoint === item.endpoint
? 'text-emerald-500 dark:text-emerald-400'
: 'text-gray-400 hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400'"
:aria-label="tooltipHint(item.endpoint)"
@click="copy(item.endpoint)"
>
<svg v-if="copiedEndpoint === item.endpoint" class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2.2">
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
</svg>
<svg v-else class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z" />
</svg>
</button>
<a
:href="speedTestUrl(item.endpoint)"
target="_blank"
rel="noopener noreferrer"
class="rounded p-0.5 text-gray-400 transition-colors hover:text-amber-500 dark:text-gray-500 dark:hover:text-amber-400"
:title="t('keys.endpoints.speedTest')"
>
<svg class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M13 10V3L4 14h7v7l9-11h-7z" />
</svg>
</a>
</div>
</div>
</div>
</template>

View File

@@ -0,0 +1,69 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
const copyToClipboard = vi.fn().mockResolvedValue(true)
const messages: Record<string, string> = {
'keys.endpoints.title': 'API 端点',
'keys.endpoints.default': '默认',
'keys.endpoints.copied': '已复制',
'keys.endpoints.copiedHint': '已复制到剪贴板',
'keys.endpoints.clickToCopy': '点击可复制此端点',
'keys.endpoints.speedTest': '测速',
}
vi.mock('vue-i18n', () => ({
useI18n: () => ({
t: (key: string) => messages[key] ?? key,
}),
}))
vi.mock('@/composables/useClipboard', () => ({
useClipboard: () => ({
copyToClipboard,
}),
}))
import EndpointPopover from '../EndpointPopover.vue'
describe('EndpointPopover', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('将说明提示渲染到 URL 上方而不是旧的 title 图标上', () => {
const wrapper = mount(EndpointPopover, {
props: {
apiBaseUrl: 'https://default.example.com/v1',
customEndpoints: [
{
name: '备用线路',
endpoint: 'https://backup.example.com/v1',
description: '自定义说明',
},
],
},
})
expect(wrapper.text()).toContain('自定义说明')
expect(wrapper.text()).toContain('点击可复制此端点')
expect(wrapper.find('[role="button"]').attributes('title')).toBeUndefined()
expect(wrapper.find('[title="自定义说明"]').exists()).toBe(false)
})
it('点击 URL 后会复制并切换为已复制提示', async () => {
const wrapper = mount(EndpointPopover, {
props: {
apiBaseUrl: 'https://default.example.com/v1',
customEndpoints: [],
},
})
await wrapper.find('[role="button"]').trigger('click')
await flushPromises()
expect(copyToClipboard).toHaveBeenCalledWith('https://default.example.com/v1', '已复制')
expect(wrapper.text()).toContain('已复制到剪贴板')
expect(wrapper.find('button[aria-label="已复制到剪贴板"]').exists()).toBe(true)
})
})

View File

@@ -533,6 +533,14 @@ export default {
title: 'API Keys', title: 'API Keys',
description: 'Manage your API keys and access tokens', description: 'Manage your API keys and access tokens',
searchPlaceholder: 'Search name or key...', searchPlaceholder: 'Search name or key...',
endpoints: {
title: 'API Endpoints',
default: 'Default',
copied: 'Copied',
copiedHint: 'Copied to clipboard',
clickToCopy: 'Click to copy this endpoint',
speedTest: 'Speed Test',
},
allGroups: 'All Groups', allGroups: 'All Groups',
allStatus: 'All Status', allStatus: 'All Status',
createKey: 'Create API Key', createKey: 'Create API Key',
@@ -4162,6 +4170,18 @@ export default {
apiBaseUrlPlaceholder: 'https://api.example.com', apiBaseUrlPlaceholder: 'https://api.example.com',
apiBaseUrlHint: apiBaseUrlHint:
'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.', 'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
customEndpoints: {
title: 'Custom Endpoints',
description: 'Add additional API endpoint URLs for users to quickly copy on the API Keys page',
itemLabel: 'Endpoint #{n}',
name: 'Name',
namePlaceholder: 'e.g., OpenAI Compatible',
endpointUrl: 'Endpoint URL',
endpointUrlPlaceholder: 'https://api2.example.com',
descriptionLabel: 'Description',
descriptionPlaceholder: 'e.g., Supports OpenAI format requests',
add: 'Add Endpoint',
},
contactInfo: 'Contact Info', contactInfo: 'Contact Info',
contactInfoPlaceholder: 'e.g., QQ: 123456789', contactInfoPlaceholder: 'e.g., QQ: 123456789',
contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.', contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.',

View File

@@ -533,6 +533,14 @@ export default {
title: 'API 密钥', title: 'API 密钥',
description: '管理您的 API 密钥和访问令牌', description: '管理您的 API 密钥和访问令牌',
searchPlaceholder: '搜索名称或Key...', searchPlaceholder: '搜索名称或Key...',
endpoints: {
title: 'API 端点',
default: '默认',
copied: '已复制',
copiedHint: '已复制到剪贴板',
clickToCopy: '点击可复制此端点',
speedTest: '测速',
},
allGroups: '全部分组', allGroups: '全部分组',
allStatus: '全部状态', allStatus: '全部状态',
createKey: '创建密钥', createKey: '创建密钥',
@@ -4324,6 +4332,18 @@ export default {
apiBaseUrl: 'API 端点地址', apiBaseUrl: 'API 端点地址',
apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址', apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址',
apiBaseUrlPlaceholder: 'https://api.example.com', apiBaseUrlPlaceholder: 'https://api.example.com',
customEndpoints: {
title: '自定义端点',
description: '添加额外的 API 端点地址用户可在「API Keys」页面快速复制',
itemLabel: '端点 #{n}',
name: '名称',
namePlaceholder: '如OpenAI Compatible',
endpointUrl: '端点地址',
endpointUrlPlaceholder: 'https://api2.example.com',
descriptionLabel: '介绍',
descriptionPlaceholder: '如:支持 OpenAI 格式请求',
add: '添加端点',
},
contactInfo: '客服联系方式', contactInfo: '客服联系方式',
contactInfoPlaceholder: '例如QQ: 123456789', contactInfoPlaceholder: '例如QQ: 123456789',
contactInfoHint: '填写客服联系方式,将展示在兑换页面、个人资料等位置', contactInfoHint: '填写客服联系方式,将展示在兑换页面、个人资料等位置',

View File

@@ -330,6 +330,7 @@ export const useAppStore = defineStore('app', () => {
purchase_subscription_enabled: false, purchase_subscription_enabled: false,
purchase_subscription_url: '', purchase_subscription_url: '',
custom_menu_items: [], custom_menu_items: [],
custom_endpoints: [],
linuxdo_oauth_enabled: false, linuxdo_oauth_enabled: false,
sora_client_enabled: false, sora_client_enabled: false,
backend_mode_enabled: false, backend_mode_enabled: false,

View File

@@ -84,6 +84,12 @@ export interface CustomMenuItem {
sort_order: number sort_order: number
} }
export interface CustomEndpoint {
name: string
endpoint: string
description: string
}
export interface PublicSettings { export interface PublicSettings {
registration_enabled: boolean registration_enabled: boolean
email_verify_enabled: boolean email_verify_enabled: boolean
@@ -104,6 +110,7 @@ export interface PublicSettings {
purchase_subscription_enabled: boolean purchase_subscription_enabled: boolean
purchase_subscription_url: string purchase_subscription_url: string
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
linuxdo_oauth_enabled: boolean linuxdo_oauth_enabled: boolean
sora_client_enabled: boolean sora_client_enabled: boolean
backend_mode_enabled: boolean backend_mode_enabled: boolean
@@ -978,7 +985,6 @@ export interface UsageLog {
account_id: number | null account_id: number | null
request_id: string request_id: string
model: string model: string
upstream_model?: string | null
service_tier?: string | null service_tier?: string | null
reasoning_effort?: string | null reasoning_effort?: string | null
inbound_endpoint?: string | null inbound_endpoint?: string | null
@@ -1033,6 +1039,8 @@ export interface UsageLogAccountSummary {
} }
export interface AdminUsageLog extends UsageLog { export interface AdminUsageLog extends UsageLog {
upstream_model?: string | null
// 账号计费倍率(仅管理员可见) // 账号计费倍率(仅管理员可见)
account_rate_multiplier?: number | null account_rate_multiplier?: number | null

View File

@@ -7,7 +7,7 @@
</div> </div>
<!-- Settings Form --> <!-- Settings Form -->
<form v-else @submit.prevent="saveSettings" class="space-y-6"> <form v-else @submit.prevent="saveSettings" class="space-y-6" novalidate>
<!-- Tab Navigation --> <!-- Tab Navigation -->
<div class="sticky top-0 z-10 overflow-x-auto settings-tabs-scroll"> <div class="sticky top-0 z-10 overflow-x-auto settings-tabs-scroll">
<nav class="settings-tabs"> <nav class="settings-tabs">
@@ -1248,6 +1248,81 @@
</p> </p>
</div> </div>
<!-- Custom Endpoints -->
<div>
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.site.customEndpoints.title') }}
</label>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.description') }}
</p>
<div class="space-y-3">
<div
v-for="(ep, index) in form.custom_endpoints"
:key="index"
class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
>
<div class="mb-3 flex items-center justify-between">
<span class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.site.customEndpoints.itemLabel', { n: index + 1 }) }}
</span>
<button
type="button"
class="rounded p-1 text-red-400 hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
@click="removeEndpoint(index)"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" /></svg>
</button>
</div>
<div class="grid grid-cols-1 gap-3 sm:grid-cols-2">
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.name') }}
</label>
<input
v-model="ep.name"
type="text"
class="input text-sm"
:placeholder="t('admin.settings.site.customEndpoints.namePlaceholder')"
/>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.endpointUrl') }}
</label>
<input
v-model="ep.endpoint"
type="url"
class="input font-mono text-sm"
:placeholder="t('admin.settings.site.customEndpoints.endpointUrlPlaceholder')"
/>
</div>
<div class="sm:col-span-2">
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
{{ t('admin.settings.site.customEndpoints.descriptionLabel') }}
</label>
<input
v-model="ep.description"
type="text"
class="input text-sm"
:placeholder="t('admin.settings.site.customEndpoints.descriptionPlaceholder')"
/>
</div>
</div>
</div>
</div>
<button
type="button"
class="mt-3 flex w-full items-center justify-center gap-2 rounded-lg border-2 border-dashed border-gray-300 px-4 py-2.5 text-sm text-gray-500 transition-colors hover:border-primary-400 hover:text-primary-600 dark:border-dark-600 dark:text-gray-400 dark:hover:border-primary-500 dark:hover:text-primary-400"
@click="addEndpoint"
>
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M12 4v16m8-8H4" /></svg>
{{ t('admin.settings.site.customEndpoints.add') }}
</button>
</div>
<!-- Contact Info --> <!-- Contact Info -->
<div> <div>
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300"> <label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
@@ -1945,6 +2020,7 @@ const form = reactive<SettingsForm>({
purchase_subscription_url: '', purchase_subscription_url: '',
sora_client_enabled: false, sora_client_enabled: false,
custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>, custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>,
custom_endpoints: [] as Array<{name: string; endpoint: string; description: string}>,
frontend_url: '', frontend_url: '',
smtp_host: '', smtp_host: '',
smtp_port: 587, smtp_port: 587,
@@ -2114,6 +2190,15 @@ function moveMenuItem(index: number, direction: -1 | 1) {
}) })
} }
// Custom endpoint management
function addEndpoint() {
form.custom_endpoints.push({ name: '', endpoint: '', description: '' })
}
function removeEndpoint(index: number) {
form.custom_endpoints.splice(index, 1)
}
async function loadSettings() { async function loadSettings() {
loading.value = true loading.value = true
try { try {
@@ -2198,6 +2283,35 @@ async function saveSettings() {
return return
} }
// Validate URL fields — novalidate disables browser-native checks, so we validate here
const isValidHttpUrl = (url: string): boolean => {
if (!url) return true
try {
const u = new URL(url)
return u.protocol === 'http:' || u.protocol === 'https:'
} catch {
return false
}
}
// Optional URL fields: auto-clear invalid values so they don't cause backend 400 errors
if (!isValidHttpUrl(form.frontend_url)) form.frontend_url = ''
if (!isValidHttpUrl(form.doc_url)) form.doc_url = ''
// Purchase URL: required when enabled; auto-clear when disabled to avoid backend rejection
if (form.purchase_subscription_enabled) {
if (!form.purchase_subscription_url) {
appStore.showError(t('admin.settings.purchase.url') + ': URL is required when purchase is enabled')
saving.value = false
return
}
if (!isValidHttpUrl(form.purchase_subscription_url)) {
appStore.showError(t('admin.settings.purchase.url') + ': must be an absolute http(s) URL (e.g. https://example.com)')
saving.value = false
return
}
} else if (!isValidHttpUrl(form.purchase_subscription_url)) {
form.purchase_subscription_url = ''
}
const payload: UpdateSettingsRequest = { const payload: UpdateSettingsRequest = {
registration_enabled: form.registration_enabled, registration_enabled: form.registration_enabled,
email_verify_enabled: form.email_verify_enabled, email_verify_enabled: form.email_verify_enabled,
@@ -2224,6 +2338,7 @@ async function saveSettings() {
purchase_subscription_url: form.purchase_subscription_url, purchase_subscription_url: form.purchase_subscription_url,
sora_client_enabled: form.sora_client_enabled, sora_client_enabled: form.sora_client_enabled,
custom_menu_items: form.custom_menu_items, custom_menu_items: form.custom_menu_items,
custom_endpoints: form.custom_endpoints,
frontend_url: form.frontend_url, frontend_url: form.frontend_url,
smtp_host: form.smtp_host, smtp_host: form.smtp_host,
smtp_port: form.smtp_port, smtp_port: form.smtp_port,

View File

@@ -2,24 +2,31 @@
<AppLayout> <AppLayout>
<TablePageLayout> <TablePageLayout>
<template #filters> <template #filters>
<div class="flex flex-wrap items-center gap-3"> <div class="flex flex-col gap-3">
<SearchInput <div class="flex flex-wrap items-center gap-3">
v-model="filterSearch" <SearchInput
:placeholder="t('keys.searchPlaceholder')" v-model="filterSearch"
class="w-full sm:w-64" :placeholder="t('keys.searchPlaceholder')"
@search="onFilterChange" class="w-full sm:w-64"
/> @search="onFilterChange"
<Select />
:model-value="filterGroupId" <Select
class="w-40" :model-value="filterGroupId"
:options="groupFilterOptions" class="w-40"
@update:model-value="onGroupFilterChange" :options="groupFilterOptions"
/> @update:model-value="onGroupFilterChange"
<Select />
:model-value="filterStatus" <Select
class="w-40" :model-value="filterStatus"
:options="statusFilterOptions" class="w-40"
@update:model-value="onStatusFilterChange" :options="statusFilterOptions"
@update:model-value="onStatusFilterChange"
/>
</div>
<EndpointPopover
v-if="publicSettings?.api_base_url || (publicSettings?.custom_endpoints?.length ?? 0) > 0"
:api-base-url="publicSettings?.api_base_url || ''"
:custom-endpoints="publicSettings?.custom_endpoints || []"
/> />
</div> </div>
</template> </template>
@@ -1050,6 +1057,7 @@ import TablePageLayout from '@/components/layout/TablePageLayout.vue'
import SearchInput from '@/components/common/SearchInput.vue' import SearchInput from '@/components/common/SearchInput.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import UseKeyModal from '@/components/keys/UseKeyModal.vue' import UseKeyModal from '@/components/keys/UseKeyModal.vue'
import EndpointPopover from '@/components/keys/EndpointPopover.vue'
import GroupBadge from '@/components/common/GroupBadge.vue' import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue' import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import type { ApiKey, Group, PublicSettings, SubscriptionType, GroupPlatform } from '@/types' import type { ApiKey, Group, PublicSettings, SubscriptionType, GroupPlatform } from '@/types'