mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
Merge branch 'main' of github.com:InCerryGit/sub2api
# Conflicts: # backend/internal/service/billing_service.go
This commit is contained in:
@@ -716,6 +716,7 @@ var (
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "request_id", Type: field.TypeString, Size: 64},
|
||||
{Name: "model", Type: field.TypeString, Size: 100},
|
||||
{Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
|
||||
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
|
||||
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
|
||||
@@ -756,31 +757,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@@ -789,38 +790,43 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[30]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[31]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[29]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[2]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_requested_model",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[3]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_request_id",
|
||||
Unique: false,
|
||||
@@ -829,17 +835,17 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
|
||||
id *int64
|
||||
request_id *string
|
||||
model *string
|
||||
requested_model *string
|
||||
upstream_model *string
|
||||
input_tokens *int
|
||||
addinput_tokens *int
|
||||
@@ -18577,6 +18578,55 @@ func (m *UsageLogMutation) ResetModel() {
|
||||
m.model = nil
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (m *UsageLogMutation) SetRequestedModel(s string) {
|
||||
m.requested_model = &s
|
||||
}
|
||||
|
||||
// RequestedModel returns the value of the "requested_model" field in the mutation.
|
||||
func (m *UsageLogMutation) RequestedModel() (r string, exists bool) {
|
||||
v := m.requested_model
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldRequestedModel returns the old "requested_model" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldRequestedModel(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldRequestedModel is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldRequestedModel requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldRequestedModel: %w", err)
|
||||
}
|
||||
return oldValue.RequestedModel, nil
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (m *UsageLogMutation) ClearRequestedModel() {
|
||||
m.requested_model = nil
|
||||
m.clearedFields[usagelog.FieldRequestedModel] = struct{}{}
|
||||
}
|
||||
|
||||
// RequestedModelCleared returns if the "requested_model" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) RequestedModelCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldRequestedModel]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetRequestedModel resets all changes to the "requested_model" field.
|
||||
func (m *UsageLogMutation) ResetRequestedModel() {
|
||||
m.requested_model = nil
|
||||
delete(m.clearedFields, usagelog.FieldRequestedModel)
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (m *UsageLogMutation) SetUpstreamModel(s string) {
|
||||
m.upstream_model = &s
|
||||
@@ -20247,7 +20297,7 @@ func (m *UsageLogMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *UsageLogMutation) Fields() []string {
|
||||
fields := make([]string, 0, 33)
|
||||
fields := make([]string, 0, 34)
|
||||
if m.user != nil {
|
||||
fields = append(fields, usagelog.FieldUserID)
|
||||
}
|
||||
@@ -20263,6 +20313,9 @@ func (m *UsageLogMutation) Fields() []string {
|
||||
if m.model != nil {
|
||||
fields = append(fields, usagelog.FieldModel)
|
||||
}
|
||||
if m.requested_model != nil {
|
||||
fields = append(fields, usagelog.FieldRequestedModel)
|
||||
}
|
||||
if m.upstream_model != nil {
|
||||
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||
}
|
||||
@@ -20365,6 +20418,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.RequestID()
|
||||
case usagelog.FieldModel:
|
||||
return m.Model()
|
||||
case usagelog.FieldRequestedModel:
|
||||
return m.RequestedModel()
|
||||
case usagelog.FieldUpstreamModel:
|
||||
return m.UpstreamModel()
|
||||
case usagelog.FieldGroupID:
|
||||
@@ -20440,6 +20495,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
||||
return m.OldRequestID(ctx)
|
||||
case usagelog.FieldModel:
|
||||
return m.OldModel(ctx)
|
||||
case usagelog.FieldRequestedModel:
|
||||
return m.OldRequestedModel(ctx)
|
||||
case usagelog.FieldUpstreamModel:
|
||||
return m.OldUpstreamModel(ctx)
|
||||
case usagelog.FieldGroupID:
|
||||
@@ -20540,6 +20597,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetModel(v)
|
||||
return nil
|
||||
case usagelog.FieldRequestedModel:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetRequestedModel(v)
|
||||
return nil
|
||||
case usagelog.FieldUpstreamModel:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
@@ -20985,6 +21049,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
|
||||
// mutation.
|
||||
func (m *UsageLogMutation) ClearedFields() []string {
|
||||
var fields []string
|
||||
if m.FieldCleared(usagelog.FieldRequestedModel) {
|
||||
fields = append(fields, usagelog.FieldRequestedModel)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldUpstreamModel) {
|
||||
fields = append(fields, usagelog.FieldUpstreamModel)
|
||||
}
|
||||
@@ -21029,6 +21096,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
|
||||
// error if the field is not defined in the schema.
|
||||
func (m *UsageLogMutation) ClearField(name string) error {
|
||||
switch name {
|
||||
case usagelog.FieldRequestedModel:
|
||||
m.ClearRequestedModel()
|
||||
return nil
|
||||
case usagelog.FieldUpstreamModel:
|
||||
m.ClearUpstreamModel()
|
||||
return nil
|
||||
@@ -21082,6 +21152,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
||||
case usagelog.FieldModel:
|
||||
m.ResetModel()
|
||||
return nil
|
||||
case usagelog.FieldRequestedModel:
|
||||
m.ResetRequestedModel()
|
||||
return nil
|
||||
case usagelog.FieldUpstreamModel:
|
||||
m.ResetUpstreamModel()
|
||||
return nil
|
||||
|
||||
@@ -821,96 +821,100 @@ func init() {
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// usagelogDescRequestedModel is the schema descriptor for requested_model field.
|
||||
usagelogDescRequestedModel := usagelogFields[5].Descriptor()
|
||||
// usagelog.RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save.
|
||||
usagelog.RequestedModelValidator = usagelogDescRequestedModel.Validators[0].(func(string) error)
|
||||
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
|
||||
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
|
||||
usagelogDescUpstreamModel := usagelogFields[6].Descriptor()
|
||||
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
|
||||
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
|
||||
usagelogDescInputTokens := usagelogFields[8].Descriptor()
|
||||
usagelogDescInputTokens := usagelogFields[9].Descriptor()
|
||||
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
|
||||
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
|
||||
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
|
||||
usagelogDescOutputTokens := usagelogFields[9].Descriptor()
|
||||
usagelogDescOutputTokens := usagelogFields[10].Descriptor()
|
||||
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
||||
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
|
||||
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
|
||||
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
|
||||
usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor()
|
||||
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
|
||||
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
|
||||
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
|
||||
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
|
||||
usagelogDescCacheReadTokens := usagelogFields[12].Descriptor()
|
||||
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
|
||||
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
|
||||
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
|
||||
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
|
||||
usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor()
|
||||
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
|
||||
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
|
||||
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
|
||||
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
|
||||
usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor()
|
||||
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
|
||||
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
|
||||
// usagelogDescInputCost is the schema descriptor for input_cost field.
|
||||
usagelogDescInputCost := usagelogFields[14].Descriptor()
|
||||
usagelogDescInputCost := usagelogFields[15].Descriptor()
|
||||
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
|
||||
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
|
||||
// usagelogDescOutputCost is the schema descriptor for output_cost field.
|
||||
usagelogDescOutputCost := usagelogFields[15].Descriptor()
|
||||
usagelogDescOutputCost := usagelogFields[16].Descriptor()
|
||||
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
|
||||
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
|
||||
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
|
||||
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
|
||||
usagelogDescCacheCreationCost := usagelogFields[17].Descriptor()
|
||||
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
|
||||
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
|
||||
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
|
||||
usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
|
||||
usagelogDescCacheReadCost := usagelogFields[18].Descriptor()
|
||||
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
|
||||
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
|
||||
// usagelogDescTotalCost is the schema descriptor for total_cost field.
|
||||
usagelogDescTotalCost := usagelogFields[18].Descriptor()
|
||||
usagelogDescTotalCost := usagelogFields[19].Descriptor()
|
||||
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
|
||||
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
|
||||
// usagelogDescActualCost is the schema descriptor for actual_cost field.
|
||||
usagelogDescActualCost := usagelogFields[19].Descriptor()
|
||||
usagelogDescActualCost := usagelogFields[20].Descriptor()
|
||||
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
|
||||
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
|
||||
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
|
||||
usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
|
||||
usagelogDescRateMultiplier := usagelogFields[21].Descriptor()
|
||||
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
|
||||
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
|
||||
// usagelogDescBillingType is the schema descriptor for billing_type field.
|
||||
usagelogDescBillingType := usagelogFields[22].Descriptor()
|
||||
usagelogDescBillingType := usagelogFields[23].Descriptor()
|
||||
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
|
||||
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
|
||||
// usagelogDescStream is the schema descriptor for stream field.
|
||||
usagelogDescStream := usagelogFields[23].Descriptor()
|
||||
usagelogDescStream := usagelogFields[24].Descriptor()
|
||||
// usagelog.DefaultStream holds the default value on creation for the stream field.
|
||||
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
|
||||
// usagelogDescUserAgent is the schema descriptor for user_agent field.
|
||||
usagelogDescUserAgent := usagelogFields[26].Descriptor()
|
||||
usagelogDescUserAgent := usagelogFields[27].Descriptor()
|
||||
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
|
||||
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
|
||||
// usagelogDescIPAddress is the schema descriptor for ip_address field.
|
||||
usagelogDescIPAddress := usagelogFields[27].Descriptor()
|
||||
usagelogDescIPAddress := usagelogFields[28].Descriptor()
|
||||
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
|
||||
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
|
||||
// usagelogDescImageCount is the schema descriptor for image_count field.
|
||||
usagelogDescImageCount := usagelogFields[28].Descriptor()
|
||||
usagelogDescImageCount := usagelogFields[29].Descriptor()
|
||||
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
|
||||
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
|
||||
// usagelogDescImageSize is the schema descriptor for image_size field.
|
||||
usagelogDescImageSize := usagelogFields[29].Descriptor()
|
||||
usagelogDescImageSize := usagelogFields[30].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescMediaType is the schema descriptor for media_type field.
|
||||
usagelogDescMediaType := usagelogFields[30].Descriptor()
|
||||
usagelogDescMediaType := usagelogFields[31].Descriptor()
|
||||
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
|
||||
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
|
||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor()
|
||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[32].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[33].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
|
||||
@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
|
||||
field.String("model").
|
||||
MaxLen(100).
|
||||
NotEmpty(),
|
||||
// RequestedModel stores the client-requested model name for stable display and analytics.
|
||||
// NULL means historical rows written before requested_model dual-write was introduced.
|
||||
field.String("requested_model").
|
||||
MaxLen(100).
|
||||
Optional().
|
||||
Nillable(),
|
||||
// UpstreamModel stores the actual upstream model name when model mapping
|
||||
// is applied. NULL means no mapping — the requested model was used as-is.
|
||||
field.String("upstream_model").
|
||||
@@ -181,6 +187,7 @@ func (UsageLog) Indexes() []ent.Index {
|
||||
index.Fields("subscription_id"),
|
||||
index.Fields("created_at"),
|
||||
index.Fields("model"),
|
||||
index.Fields("requested_model"),
|
||||
index.Fields("request_id"),
|
||||
// 复合索引用于时间范围查询
|
||||
index.Fields("user_id", "created_at"),
|
||||
|
||||
@@ -32,6 +32,8 @@ type UsageLog struct {
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
// Model holds the value of the "model" field.
|
||||
Model string `json:"model,omitempty"`
|
||||
// RequestedModel holds the value of the "requested_model" field.
|
||||
RequestedModel *string `json:"requested_model,omitempty"`
|
||||
// UpstreamModel holds the value of the "upstream_model" field.
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
// GroupID holds the value of the "group_id" field.
|
||||
@@ -177,7 +179,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagelog.FieldCreatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@@ -232,6 +234,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Model = value.String
|
||||
}
|
||||
case usagelog.FieldRequestedModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field requested_model", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RequestedModel = new(string)
|
||||
*_m.RequestedModel = value.String
|
||||
}
|
||||
case usagelog.FieldUpstreamModel:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
|
||||
@@ -486,6 +495,11 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString("model=")
|
||||
builder.WriteString(_m.Model)
|
||||
builder.WriteString(", ")
|
||||
if v := _m.RequestedModel; v != nil {
|
||||
builder.WriteString("requested_model=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.UpstreamModel; v != nil {
|
||||
builder.WriteString("upstream_model=")
|
||||
builder.WriteString(*v)
|
||||
|
||||
@@ -24,6 +24,8 @@ const (
|
||||
FieldRequestID = "request_id"
|
||||
// FieldModel holds the string denoting the model field in the database.
|
||||
FieldModel = "model"
|
||||
// FieldRequestedModel holds the string denoting the requested_model field in the database.
|
||||
FieldRequestedModel = "requested_model"
|
||||
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
|
||||
FieldUpstreamModel = "upstream_model"
|
||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||
@@ -137,6 +139,7 @@ var Columns = []string{
|
||||
FieldAccountID,
|
||||
FieldRequestID,
|
||||
FieldModel,
|
||||
FieldRequestedModel,
|
||||
FieldUpstreamModel,
|
||||
FieldGroupID,
|
||||
FieldSubscriptionID,
|
||||
@@ -182,6 +185,8 @@ var (
|
||||
RequestIDValidator func(string) error
|
||||
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
|
||||
ModelValidator func(string) error
|
||||
// RequestedModelValidator is a validator for the "requested_model" field. It is called by the builders before save.
|
||||
RequestedModelValidator func(string) error
|
||||
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
|
||||
UpstreamModelValidator func(string) error
|
||||
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
|
||||
@@ -263,6 +268,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRequestedModel orders the results by the requested_model field.
|
||||
func ByRequestedModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRequestedModel, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUpstreamModel orders the results by the upstream_model field.
|
||||
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
|
||||
|
||||
@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
|
||||
}
|
||||
|
||||
// RequestedModel applies equality check predicate on the "requested_model" field. It's identical to RequestedModelEQ.
|
||||
func RequestedModel(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
|
||||
func UpstreamModel(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||
@@ -410,6 +415,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelEQ applies the EQ predicate on the "requested_model" field.
|
||||
func RequestedModelEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelNEQ applies the NEQ predicate on the "requested_model" field.
|
||||
func RequestedModelNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelIn applies the In predicate on the "requested_model" field.
|
||||
func RequestedModelIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldRequestedModel, vs...))
|
||||
}
|
||||
|
||||
// RequestedModelNotIn applies the NotIn predicate on the "requested_model" field.
|
||||
func RequestedModelNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldRequestedModel, vs...))
|
||||
}
|
||||
|
||||
// RequestedModelGT applies the GT predicate on the "requested_model" field.
|
||||
func RequestedModelGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelGTE applies the GTE predicate on the "requested_model" field.
|
||||
func RequestedModelGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelLT applies the LT predicate on the "requested_model" field.
|
||||
func RequestedModelLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelLTE applies the LTE predicate on the "requested_model" field.
|
||||
func RequestedModelLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelContains applies the Contains predicate on the "requested_model" field.
|
||||
func RequestedModelContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelHasPrefix applies the HasPrefix predicate on the "requested_model" field.
|
||||
func RequestedModelHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelHasSuffix applies the HasSuffix predicate on the "requested_model" field.
|
||||
func RequestedModelHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelIsNil applies the IsNil predicate on the "requested_model" field.
|
||||
func RequestedModelIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldRequestedModel))
|
||||
}
|
||||
|
||||
// RequestedModelNotNil applies the NotNil predicate on the "requested_model" field.
|
||||
func RequestedModelNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldRequestedModel))
|
||||
}
|
||||
|
||||
// RequestedModelEqualFold applies the EqualFold predicate on the "requested_model" field.
|
||||
func RequestedModelEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// RequestedModelContainsFold applies the ContainsFold predicate on the "requested_model" field.
|
||||
func RequestedModelContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldRequestedModel, v))
|
||||
}
|
||||
|
||||
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
|
||||
func UpstreamModelEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
|
||||
|
||||
@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (_c *UsageLogCreate) SetRequestedModel(v string) *UsageLogCreate {
|
||||
_c.mutation.SetRequestedModel(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableRequestedModel(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetRequestedModel(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
|
||||
_c.mutation.SetUpstreamModel(v)
|
||||
@@ -610,6 +624,11 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.RequestedModel(); ok {
|
||||
if err := usagelog.RequestedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
@@ -733,6 +752,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
_node.Model = value
|
||||
}
|
||||
if value, ok := _c.mutation.RequestedModel(); ok {
|
||||
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
|
||||
_node.RequestedModel = &value
|
||||
}
|
||||
if value, ok := _c.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
_node.UpstreamModel = &value
|
||||
@@ -1034,6 +1057,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (u *UsageLogUpsert) SetRequestedModel(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldRequestedModel, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateRequestedModel() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldRequestedModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (u *UsageLogUpsert) ClearRequestedModel() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldRequestedModel)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldUpstreamModel, v)
|
||||
@@ -1641,6 +1682,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (u *UsageLogUpsertOne) SetRequestedModel(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetRequestedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateRequestedModel() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateRequestedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (u *UsageLogUpsertOne) ClearRequestedModel() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearRequestedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@@ -2496,6 +2558,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (u *UsageLogUpsertBulk) SetRequestedModel(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetRequestedModel(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRequestedModel sets the "requested_model" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateRequestedModel() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateRequestedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (u *UsageLogUpsertBulk) ClearRequestedModel() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearRequestedModel()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
|
||||
@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (_u *UsageLogUpdate) SetRequestedModel(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetRequestedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableRequestedModel(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetRequestedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (_u *UsageLogUpdate) ClearRequestedModel() *UsageLogUpdate {
|
||||
_u.mutation.ClearRequestedModel()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetUpstreamModel(v)
|
||||
@@ -765,6 +785,11 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.RequestedModel(); ok {
|
||||
if err := usagelog.RequestedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
@@ -820,6 +845,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.Model(); ok {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RequestedModel(); ok {
|
||||
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.RequestedModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldRequestedModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
}
|
||||
@@ -1208,6 +1239,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRequestedModel sets the "requested_model" field.
|
||||
func (_u *UsageLogUpdateOne) SetRequestedModel(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetRequestedModel(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRequestedModel sets the "requested_model" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableRequestedModel(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRequestedModel(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearRequestedModel clears the value of the "requested_model" field.
|
||||
func (_u *UsageLogUpdateOne) ClearRequestedModel() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearRequestedModel()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpstreamModel sets the "upstream_model" field.
|
||||
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetUpstreamModel(v)
|
||||
@@ -1884,6 +1935,11 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.RequestedModel(); ok {
|
||||
if err := usagelog.RequestedModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "requested_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.requested_model": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.UpstreamModel(); ok {
|
||||
if err := usagelog.UpstreamModelValidator(v); err != nil {
|
||||
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
|
||||
@@ -1956,6 +2012,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if value, ok := _u.mutation.Model(); ok {
|
||||
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RequestedModel(); ok {
|
||||
_spec.SetField(usagelog.FieldRequestedModel, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.RequestedModelCleared() {
|
||||
_spec.ClearField(usagelog.FieldRequestedModel, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.UpstreamModel(); ok {
|
||||
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
|
||||
}
|
||||
|
||||
@@ -94,6 +94,10 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
@@ -195,6 +199,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -230,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@@ -263,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -314,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
|
||||
@@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
@@ -176,6 +177,7 @@ type UpdateSettingsRequest struct {
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -417,6 +419,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
customMenuJSON = string(menuBytes)
|
||||
}
|
||||
|
||||
// 自定义端点验证
|
||||
const (
|
||||
maxCustomEndpoints = 10
|
||||
maxEndpointNameLen = 50
|
||||
maxEndpointURLLen = 2048
|
||||
maxEndpointDescriptionLen = 200
|
||||
)
|
||||
|
||||
customEndpointsJSON := previousSettings.CustomEndpoints
|
||||
if req.CustomEndpoints != nil {
|
||||
endpoints := *req.CustomEndpoints
|
||||
if len(endpoints) > maxCustomEndpoints {
|
||||
response.BadRequest(c, "Too many custom endpoints (max 10)")
|
||||
return
|
||||
}
|
||||
for _, ep := range endpoints {
|
||||
if strings.TrimSpace(ep.Name) == "" {
|
||||
response.BadRequest(c, "Custom endpoint name is required")
|
||||
return
|
||||
}
|
||||
if len(ep.Name) > maxEndpointNameLen {
|
||||
response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ep.Endpoint) == "" {
|
||||
response.BadRequest(c, "Custom endpoint URL is required")
|
||||
return
|
||||
}
|
||||
if len(ep.Endpoint) > maxEndpointURLLen {
|
||||
response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil {
|
||||
response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
if len(ep.Description) > maxEndpointDescriptionLen {
|
||||
response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)")
|
||||
return
|
||||
}
|
||||
}
|
||||
endpointBytes, err := json.Marshal(endpoints)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to serialize custom endpoints")
|
||||
return
|
||||
}
|
||||
customEndpointsJSON = string(endpointBytes)
|
||||
}
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -495,6 +546,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
@@ -592,6 +644,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
||||
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
|
||||
@@ -276,11 +276,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
out.QuotaDailyLimit = &limit
|
||||
used := a.GetQuotaDailyUsed()
|
||||
if a.IsDailyQuotaPeriodExpired() {
|
||||
used = 0
|
||||
}
|
||||
out.QuotaDailyUsed = &used
|
||||
}
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
out.QuotaWeeklyLimit = &limit
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
if a.IsWeeklyQuotaPeriodExpired() {
|
||||
used = 0
|
||||
}
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
@@ -516,14 +522,17 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
requestType := l.EffectiveRequestType()
|
||||
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
|
||||
requestedModel := l.RequestedModel
|
||||
if requestedModel == "" {
|
||||
requestedModel = l.Model
|
||||
}
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
UpstreamModel: l.UpstreamModel,
|
||||
Model: requestedModel,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
InboundEndpoint: l.InboundEndpoint,
|
||||
@@ -580,6 +589,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
||||
}
|
||||
return &AdminUsageLog{
|
||||
UsageLog: usageLogFromServiceUser(l),
|
||||
UpstreamModel: l.UpstreamModel,
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
IPAddress: l.IPAddress,
|
||||
Account: AccountSummaryFromService(l.Account),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -106,6 +107,47 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_UsesRequestedModelAndKeepsUpstreamAdminOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upstreamModel := "claude-sonnet-4-20250514"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_4",
|
||||
Model: upstreamModel,
|
||||
RequestedModel: "claude-sonnet-4",
|
||||
UpstreamModel: &upstreamModel,
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.Equal(t, "claude-sonnet-4", userDTO.Model)
|
||||
require.Equal(t, "claude-sonnet-4", adminDTO.Model)
|
||||
|
||||
userJSON, err := json.Marshal(userDTO)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(userJSON), "upstream_model")
|
||||
|
||||
adminJSON, err := json.Marshal(adminDTO)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(adminJSON), `"upstream_model":"claude-sonnet-4-20250514"`)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_legacy",
|
||||
Model: "claude-3",
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.Equal(t, "claude-3", userDTO.Model)
|
||||
require.Equal(t, "claude-3", adminDTO.Model)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -15,6 +15,13 @@ type CustomMenuItem struct {
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
// CustomEndpoint represents an admin-configured API endpoint for quick copy.
|
||||
type CustomEndpoint struct {
|
||||
Name string `json:"name"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
@@ -56,6 +63,7 @@ type SystemSettings struct {
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -114,6 +122,7 @@ type PublicSettings struct {
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
@@ -218,3 +227,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem {
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomEndpoints(raw string) []CustomEndpoint {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || raw == "[]" {
|
||||
return []CustomEndpoint{}
|
||||
}
|
||||
var items []CustomEndpoint
|
||||
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||
return []CustomEndpoint{}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
@@ -334,9 +334,6 @@ type UsageLog struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||
// Omitted when no mapping was applied (requested model was used as-is).
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
@@ -396,6 +393,10 @@ type UsageLog struct {
|
||||
type AdminUsageLog struct {
|
||||
UsageLog
|
||||
|
||||
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||
// Omitted when no mapping was applied (requested model was used as-is).
|
||||
UpstreamModel *string `json:"upstream_model,omitempty"`
|
||||
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -37,6 +37,16 @@ type OpenAIGatewayHandler struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackModel string) string {
|
||||
if fallbackModel = strings.TrimSpace(fallbackModel); fallbackModel != "" {
|
||||
return fallbackModel
|
||||
}
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
@@ -657,9 +667,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
|
||||
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
|
||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -352,6 +352,30 @@ func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
|
||||
t.Run("prefers_explicit_fallback_model", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
|
||||
}
|
||||
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
|
||||
})
|
||||
|
||||
t.Run("uses_group_default_on_normal_path", func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
|
||||
}
|
||||
require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, ""))
|
||||
})
|
||||
|
||||
t.Run("returns_empty_without_group_default", func(t *testing.T) {
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, ""))
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, ""))
|
||||
require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{
|
||||
Group: &service.Group{},
|
||||
}, ""))
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
|
||||
@@ -632,8 +632,8 @@ func TestAnthropicToResponses_ThinkingEnabled(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default high applies.
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.Contains(t, resp.Include, "reasoning.encrypted_content")
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
@@ -650,8 +650,8 @@ func TestAnthropicToResponses_ThinkingAdaptive(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
// thinking.type is ignored for effort; default xhigh applies.
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
// thinking.type is ignored for effort; default high applies.
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
assert.NotContains(t, resp.Include, "reasoning.summary")
|
||||
}
|
||||
@@ -666,9 +666,9 @@ func TestAnthropicToResponses_ThinkingDisabled(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
// Default effort applies (high → xhigh) even when thinking is disabled.
|
||||
// Default effort applies (high → high) even when thinking is disabled.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
@@ -680,9 +680,9 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
// Default effort applies (high → xhigh) when no thinking/output_config is set.
|
||||
// Default effort applies (high → high) when no thinking/output_config is set.
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -690,7 +690,7 @@ func TestAnthropicToResponses_NoThinking(t *testing.T) {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
|
||||
// Default is xhigh, but output_config.effort="low" overrides. low→low after mapping.
|
||||
// Default is high, but output_config.effort="low" overrides. low→low after mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@@ -708,7 +708,7 @@ func TestAnthropicToResponses_OutputConfigOverridesDefault(t *testing.T) {
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
|
||||
// No thinking field, but output_config.effort="medium" → creates reasoning.
|
||||
// medium→high after mapping.
|
||||
// medium→medium after 1:1 mapping.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@@ -719,12 +719,12 @@ func TestAnthropicToResponses_OutputConfigWithoutThinking(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "medium", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
|
||||
// output_config.effort="high" → mapped to "xhigh".
|
||||
// output_config.effort="high" → mapped to "high" (1:1, both sides' default).
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@@ -732,6 +732,22 @@ func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "high"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "auto", resp.Reasoning.Summary)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigMax(t *testing.T) {
|
||||
// output_config.effort="max" → mapped to OpenAI's highest supported level "xhigh".
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
OutputConfig: &AnthropicOutputConfig{Effort: "max"},
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
@@ -740,7 +756,7 @@ func TestAnthropicToResponses_OutputConfigHigh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
|
||||
// No output_config → default xhigh regardless of thinking.type.
|
||||
// No output_config → default high regardless of thinking.type.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@@ -751,11 +767,11 @@ func TestAnthropicToResponses_NoOutputConfig(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
|
||||
// output_config present but effort empty (e.g. only format set) → default xhigh.
|
||||
// output_config present but effort empty (e.g. only format set) → default high.
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
@@ -766,7 +782,7 @@ func TestAnthropicToResponses_OutputConfigWithoutEffort(t *testing.T) {
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Reasoning)
|
||||
assert.Equal(t, "xhigh", resp.Reasoning.Effort)
|
||||
assert.Equal(t, "high", resp.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -46,9 +46,10 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
}
|
||||
|
||||
// Determine reasoning effort: only output_config.effort controls the
|
||||
// level; thinking.type is ignored. Default is xhigh when unset.
|
||||
// Anthropic levels map to OpenAI: low→low, medium→high, high→xhigh.
|
||||
effort := "high" // default → maps to xhigh
|
||||
// level; thinking.type is ignored. Default is high when unset (both
|
||||
// Anthropic and OpenAI default to high).
|
||||
// Anthropic levels map 1:1 to OpenAI: low→low, medium→medium, high→high, max→xhigh.
|
||||
effort := "high" // default → both sides' default
|
||||
if req.OutputConfig != nil && req.OutputConfig.Effort != "" {
|
||||
effort = req.OutputConfig.Effort
|
||||
}
|
||||
@@ -380,18 +381,19 @@ func extractAnthropicTextFromBlocks(blocks []AnthropicContentBlock) string {
|
||||
// mapAnthropicEffortToResponses converts Anthropic reasoning effort levels to
|
||||
// OpenAI Responses API effort levels.
|
||||
//
|
||||
// Both APIs default to "high". The mapping is 1:1 for shared levels;
|
||||
// only Anthropic's "max" (Opus 4.6 exclusive) maps to OpenAI's "xhigh"
|
||||
// (GPT-5.2+ exclusive) as both represent the highest reasoning tier.
|
||||
//
|
||||
// low → low
|
||||
// medium → high
|
||||
// high → xhigh
|
||||
// medium → medium
|
||||
// high → high
|
||||
// max → xhigh
|
||||
func mapAnthropicEffortToResponses(effort string) string {
|
||||
switch effort {
|
||||
case "medium":
|
||||
return "high"
|
||||
case "high":
|
||||
if effort == "max" {
|
||||
return "xhigh"
|
||||
default:
|
||||
return effort // "low" and any unknown values pass through unchanged
|
||||
}
|
||||
return effort // low→low, medium→medium, high→high, unknown→passthrough
|
||||
}
|
||||
|
||||
// convertAnthropicToolsToResponses maps Anthropic tool definitions to
|
||||
|
||||
@@ -181,6 +181,35 @@ func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
|
||||
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "system", Content: json.RawMessage(`[{"type":"text","text":"You are a careful visual assistant."}]`)},
|
||||
{Role: "user", Content: json.RawMessage(`[{"type":"text","text":"Describe this image"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var systemParts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[0].Content, &systemParts))
|
||||
require.Len(t, systemParts, 1)
|
||||
assert.Equal(t, "input_text", systemParts[0].Type)
|
||||
assert.Equal(t, "You are a careful visual assistant.", systemParts[0].Text)
|
||||
|
||||
var userParts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &userParts))
|
||||
require.Len(t, userParts, 2)
|
||||
assert.Equal(t, "input_image", userParts[1].Type)
|
||||
assert.Equal(t, "data:image/png;base64,abc123", userParts[1].ImageURL)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
@@ -398,6 +427,45 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_ToolArrayContent(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Use the tool"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ChatToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: "inspect_image",
|
||||
Arguments: `{}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_1",
|
||||
Content: json.RawMessage(
|
||||
`[{"type":"text","text":"image width: 100"},{"type":"image_url","image_url":{"url":"data:image/png;base64,ignored"}},{"type":"text","text":"; image height: 200"}]`,
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 3)
|
||||
assert.Equal(t, "function_call_output", items[2].Type)
|
||||
assert.Equal(t, "call_1", items[2].CallID)
|
||||
assert.Equal(t, "image width: 100; image height: 200", items[2].Output)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
resp := &ResponsesResponse{
|
||||
ID: "resp_inc",
|
||||
|
||||
@@ -6,6 +6,11 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type chatMessageContent struct {
|
||||
Text *string
|
||||
Parts []ChatContentPart
|
||||
}
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
// Responses API request. The upstream always streams, so Stream is forced to
|
||||
// true. store is always false and reasoning.encrypted_content is always
|
||||
@@ -113,11 +118,11 @@ func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
// chatSystemToResponses converts a system message.
|
||||
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
text, err := parseChatContent(m.Content)
|
||||
parsed, err := parseChatMessageContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content, err := json.Marshal(text)
|
||||
content, err := marshalChatInputContent(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -127,39 +132,11 @@ func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
// chatUserToResponses converts a user message, handling both plain strings and
|
||||
// multi-modal content arrays.
|
||||
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
// Try plain string first.
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil {
|
||||
content, _ := json.Marshal(s)
|
||||
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
|
||||
}
|
||||
|
||||
var parts []ChatContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err != nil {
|
||||
parsed, err := parseChatMessageContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse user content: %w", err)
|
||||
}
|
||||
|
||||
var responseParts []ResponsesContentPart
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_text",
|
||||
Text: p.Text,
|
||||
})
|
||||
}
|
||||
case "image_url":
|
||||
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_image",
|
||||
ImageURL: p.ImageURL.URL,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content, err := json.Marshal(responseParts)
|
||||
content, err := marshalChatInputContent(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -312,16 +289,79 @@ func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
}
|
||||
|
||||
// parseChatContent returns the string value of a ChatMessage Content field.
|
||||
// Content must be a JSON string. Returns "" if content is null or empty.
|
||||
// Content can be a JSON string or an array of typed parts. Array content is
|
||||
// flattened to text by concatenating text parts and ignoring non-text parts.
|
||||
func parseChatContent(raw json.RawMessage) (string, error) {
|
||||
parsed, err := parseChatMessageContent(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsed.Text != nil {
|
||||
return *parsed.Text, nil
|
||||
}
|
||||
return flattenChatContentParts(parsed.Parts), nil
|
||||
}
|
||||
|
||||
func parseChatMessageContent(raw json.RawMessage) (chatMessageContent, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
return chatMessageContent{Text: stringPtr("")}, nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err != nil {
|
||||
return "", fmt.Errorf("parse content as string: %w", err)
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return chatMessageContent{Text: &s}, nil
|
||||
}
|
||||
return s, nil
|
||||
|
||||
var parts []ChatContentPart
|
||||
if err := json.Unmarshal(raw, &parts); err == nil {
|
||||
return chatMessageContent{Parts: parts}, nil
|
||||
}
|
||||
|
||||
return chatMessageContent{}, fmt.Errorf("parse content as string or parts array")
|
||||
}
|
||||
|
||||
func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error) {
|
||||
if content.Text != nil {
|
||||
return json.Marshal(*content.Text)
|
||||
}
|
||||
return json.Marshal(convertChatContentPartsToResponses(content.Parts))
|
||||
}
|
||||
|
||||
func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart {
|
||||
var responseParts []ResponsesContentPart
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case "text":
|
||||
if p.Text != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_text",
|
||||
Text: p.Text,
|
||||
})
|
||||
}
|
||||
case "image_url":
|
||||
if p.ImageURL != nil && p.ImageURL.URL != "" {
|
||||
responseParts = append(responseParts, ResponsesContentPart{
|
||||
Type: "input_image",
|
||||
ImageURL: p.ImageURL.URL,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return responseParts
|
||||
}
|
||||
|
||||
func flattenChatContentParts(parts []ChatContentPart) string {
|
||||
var textParts []string
|
||||
for _, p := range parts {
|
||||
if p.Type == "text" && p.Text != "" {
|
||||
textParts = append(textParts, p.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(textParts, "")
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
|
||||
|
||||
@@ -28,50 +28,64 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
||||
|
||||
// usageLogInsertArgTypes must stay in the same order as:
|
||||
// 1. prepareUsageLogInsert().args
|
||||
// 2. every INSERT/CTE VALUES column list in this file
|
||||
// 3. execUsageLogInsertNoResult placeholder positions
|
||||
// 4. scanUsageLog selected column order (via usageLogSelectColumns)
|
||||
//
|
||||
// When adding a usage_logs column, update all of those call sites together.
|
||||
var usageLogInsertArgTypes = [...]string{
|
||||
"bigint",
|
||||
"bigint",
|
||||
"bigint",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"bigint",
|
||||
"bigint",
|
||||
"integer",
|
||||
"integer",
|
||||
"integer",
|
||||
"integer",
|
||||
"integer",
|
||||
"integer",
|
||||
"numeric",
|
||||
"numeric",
|
||||
"numeric",
|
||||
"numeric",
|
||||
"numeric",
|
||||
"numeric",
|
||||
"numeric",
|
||||
"numeric",
|
||||
"smallint",
|
||||
"smallint",
|
||||
"boolean",
|
||||
"boolean",
|
||||
"integer",
|
||||
"integer",
|
||||
"text",
|
||||
"text",
|
||||
"integer",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"boolean",
|
||||
"timestamptz",
|
||||
"bigint", // user_id
|
||||
"bigint", // api_key_id
|
||||
"bigint", // account_id
|
||||
"text", // request_id
|
||||
"text", // model
|
||||
"text", // requested_model
|
||||
"text", // upstream_model
|
||||
"bigint", // group_id
|
||||
"bigint", // subscription_id
|
||||
"integer", // input_tokens
|
||||
"integer", // output_tokens
|
||||
"integer", // cache_creation_tokens
|
||||
"integer", // cache_read_tokens
|
||||
"integer", // cache_creation_5m_tokens
|
||||
"integer", // cache_creation_1h_tokens
|
||||
"numeric", // input_cost
|
||||
"numeric", // output_cost
|
||||
"numeric", // cache_creation_cost
|
||||
"numeric", // cache_read_cost
|
||||
"numeric", // total_cost
|
||||
"numeric", // actual_cost
|
||||
"numeric", // rate_multiplier
|
||||
"numeric", // account_rate_multiplier
|
||||
"smallint", // billing_type
|
||||
"smallint", // request_type
|
||||
"boolean", // stream
|
||||
"boolean", // openai_ws_mode
|
||||
"integer", // duration_ms
|
||||
"integer", // first_token_ms
|
||||
"text", // user_agent
|
||||
"text", // ip_address
|
||||
"integer", // image_count
|
||||
"text", // image_size
|
||||
"text", // media_type
|
||||
"text", // service_tier
|
||||
"text", // reasoning_effort
|
||||
"text", // inbound_endpoint
|
||||
"text", // upstream_endpoint
|
||||
"boolean", // cache_ttl_overridden
|
||||
"timestamptz", // created_at
|
||||
}
|
||||
|
||||
const rawUsageLogModelColumn = "model"
|
||||
|
||||
// rawUsageLogModelColumn preserves the exact stored usage_logs.model semantics for direct filters.
|
||||
// Historical rows may contain upstream/billing model values, while newer rows store requested_model.
|
||||
// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead.
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
"hour": "YYYY-MM-DD HH24:00",
|
||||
@@ -88,6 +102,30 @@ func safeDateFormat(granularity string) string {
|
||||
return "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
// appendRawUsageLogModelWhereCondition keeps direct model filters on the raw model column for backward
|
||||
// compatibility with historical rows. Requested/upstream analytics must use
|
||||
// resolveModelDimensionExpression instead.
|
||||
func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model string) ([]string, []any) {
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return conditions, args
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("%s = $%d", rawUsageLogModelColumn, len(args)+1))
|
||||
args = append(args, model)
|
||||
return conditions, args
|
||||
}
|
||||
|
||||
// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward
|
||||
// compatibility with historical rows. Requested/upstream analytics must use
|
||||
// resolveModelDimensionExpression instead.
|
||||
func appendRawUsageLogModelQueryFilter(query string, args []any, model string) (string, []any) {
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return query, args
|
||||
}
|
||||
query += fmt.Sprintf(" AND %s = $%d", rawUsageLogModelColumn, len(args)+1)
|
||||
args = append(args, model)
|
||||
return query, args
|
||||
}
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
@@ -278,6 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
@@ -313,12 +352,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6,
|
||||
$7, $8,
|
||||
$9, $10, $11, $12,
|
||||
$13, $14,
|
||||
$15, $16, $17, $18, $19, $20,
|
||||
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
$8, $9,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15,
|
||||
$16, $17, $18, $19, $20, $21,
|
||||
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -709,6 +748,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
@@ -779,6 +819,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
@@ -820,6 +861,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
@@ -901,6 +943,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
@@ -937,7 +980,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*39)
|
||||
args := make([]any, 0, len(preparedList)*40)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@@ -968,6 +1011,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
@@ -1009,6 +1053,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
@@ -1058,6 +1103,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
requested_model,
|
||||
upstream_model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
@@ -1093,12 +1139,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6,
|
||||
$7, $8,
|
||||
$9, $10, $11, $12,
|
||||
$13, $14,
|
||||
$15, $16, $17, $18, $19, $20,
|
||||
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
$8, $9,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15,
|
||||
$16, $17, $18, $19, $20, $21,
|
||||
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@@ -1130,6 +1176,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||
upstreamEndpoint := nullString(log.UpstreamEndpoint)
|
||||
requestedModel := strings.TrimSpace(log.RequestedModel)
|
||||
if requestedModel == "" {
|
||||
requestedModel = strings.TrimSpace(log.Model)
|
||||
}
|
||||
upstreamModel := nullString(log.UpstreamModel)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -1148,6 +1198,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
log.AccountID,
|
||||
requestIDArg,
|
||||
log.Model,
|
||||
nullString(&requestedModel),
|
||||
upstreamModel,
|
||||
groupID,
|
||||
subscriptionID,
|
||||
@@ -1702,7 +1753,7 @@ func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, acco
|
||||
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
|
||||
// 性能优化:数据库层聚合计算,避免应用层循环统计
|
||||
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
@@ -1712,8 +1763,8 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE model = $1 AND created_at >= $2 AND created_at < $3
|
||||
`
|
||||
WHERE %s = $1 AND created_at >= $2 AND created_at < $3
|
||||
`, rawUsageLogModelColumn)
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := scanSingleRow(
|
||||
@@ -1837,7 +1888,7 @@ func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn)
|
||||
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
@@ -2532,10 +2583,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
|
||||
args = append(args, filters.GroupID)
|
||||
}
|
||||
if filters.Model != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||||
args = append(args, filters.Model)
|
||||
}
|
||||
conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model)
|
||||
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
@@ -2768,10 +2816,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRawUsageLogModelQueryFilter(query, args, model)
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
@@ -3126,13 +3171,14 @@ func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayS
|
||||
|
||||
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
|
||||
func resolveModelDimensionExpression(modelType string) string {
|
||||
requestedExpr := "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
|
||||
switch usagestats.NormalizeModelSource(modelType) {
|
||||
case usagestats.ModelSourceUpstream:
|
||||
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"
|
||||
return fmt.Sprintf("COALESCE(NULLIF(TRIM(upstream_model), ''), %s)", requestedExpr)
|
||||
case usagestats.ModelSourceMapping:
|
||||
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"
|
||||
return fmt.Sprintf("(%s || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), %s))", requestedExpr, requestedExpr)
|
||||
default:
|
||||
return "model"
|
||||
return requestedExpr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3204,10 +3250,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
|
||||
args = append(args, filters.GroupID)
|
||||
}
|
||||
if filters.Model != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||||
args = append(args, filters.Model)
|
||||
}
|
||||
conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model)
|
||||
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
@@ -3336,10 +3379,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRawUsageLogModelQueryFilter(query, args, model)
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
@@ -3410,10 +3450,7 @@ func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRawUsageLogModelQueryFilter(query, args, model)
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
@@ -3888,6 +3925,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
accountID int64
|
||||
requestID sql.NullString
|
||||
model string
|
||||
requestedModel sql.NullString
|
||||
upstreamModel sql.NullString
|
||||
groupID sql.NullInt64
|
||||
subscriptionID sql.NullInt64
|
||||
@@ -3931,6 +3969,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&accountID,
|
||||
&requestID,
|
||||
&model,
|
||||
&requestedModel,
|
||||
&upstreamModel,
|
||||
&groupID,
|
||||
&subscriptionID,
|
||||
@@ -3975,6 +4014,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
APIKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
Model: model,
|
||||
RequestedModel: coalesceTrimmedString(requestedModel, model),
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
CacheCreationTokens: cacheCreationTokens,
|
||||
@@ -4181,6 +4221,13 @@ func nullString(v *string) sql.NullString {
|
||||
return sql.NullString{String: *v, Valid: true}
|
||||
}
|
||||
|
||||
func coalesceTrimmedString(v sql.NullString, fallback string) string {
|
||||
if v.Valid && strings.TrimSpace(v.String) != "" {
|
||||
return v.String
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func setToSlice(set map[int64]struct{}) []int64 {
|
||||
out := make([]int64, 0, len(set))
|
||||
for id := range set {
|
||||
|
||||
@@ -34,11 +34,11 @@ func TestResolveModelDimensionExpression(t *testing.T) {
|
||||
modelType string
|
||||
want string
|
||||
}{
|
||||
{usagestats.ModelSourceRequested, "model"},
|
||||
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
|
||||
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
|
||||
{"", "model"},
|
||||
{"invalid", "model"},
|
||||
{usagestats.ModelSourceRequested, "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
|
||||
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model))"},
|
||||
{usagestats.ModelSourceMapping, "(COALESCE(NULLIF(TRIM(requested_model), ''), model) || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), COALESCE(NULLIF(TRIM(requested_model), ''), model)))"},
|
||||
{"", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
|
||||
{"invalid", "COALESCE(NULLIF(TRIM(requested_model), ''), model)"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
@@ -21,20 +22,21 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
|
||||
createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-1",
|
||||
Model: "gpt-5",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 1,
|
||||
ActualCost: 1,
|
||||
BillingType: service.BillingTypeBalance,
|
||||
RequestType: service.RequestTypeWSV2,
|
||||
Stream: false,
|
||||
OpenAIWSMode: false,
|
||||
CreatedAt: createdAt,
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-1",
|
||||
Model: "gpt-5",
|
||||
RequestedModel: "gpt-5",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 1,
|
||||
ActualCost: 1,
|
||||
BillingType: service.BillingTypeBalance,
|
||||
RequestType: service.RequestTypeWSV2,
|
||||
Stream: false,
|
||||
OpenAIWSMode: false,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
@@ -44,6 +46,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
log.RequestedModel,
|
||||
sqlmock.AnyArg(), // upstream_model
|
||||
sqlmock.AnyArg(), // group_id
|
||||
sqlmock.AnyArg(), // subscription_id
|
||||
@@ -99,13 +102,14 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC)
|
||||
serviceTier := "priority"
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-service-tier",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
CreatedAt: createdAt,
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-service-tier",
|
||||
Model: "gpt-5.4",
|
||||
RequestedModel: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
@@ -115,6 +119,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
log.RequestedModel,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
@@ -158,6 +163,75 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestBuildUsageLogBestEffortInsertQuery_IncludesRequestedModelColumn(t *testing.T) {
|
||||
prepared := prepareUsageLogInsert(&service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-best-effort-query",
|
||||
Model: "gpt-5",
|
||||
RequestedModel: "gpt-5",
|
||||
CreatedAt: time.Date(2025, 1, 3, 12, 0, 0, 0, time.UTC),
|
||||
})
|
||||
|
||||
query, args := buildUsageLogBestEffortInsertQuery([]usageLogInsertPrepared{prepared})
|
||||
|
||||
require.Contains(t, query, "INSERT INTO usage_logs (")
|
||||
require.Contains(t, query, "\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,")
|
||||
require.Contains(t, query, "\n\t\t\trequest_id,\n\t\t\tmodel,\n\t\t\trequested_model,\n\t\t\tupstream_model,")
|
||||
require.Len(t, args, len(prepared.args))
|
||||
require.Equal(t, prepared.args[5], args[5])
|
||||
}
|
||||
|
||||
func TestExecUsageLogInsertNoResult_PersistsRequestedModel(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
prepared := prepareUsageLogInsert(&service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-best-effort-exec",
|
||||
Model: "gpt-5",
|
||||
RequestedModel: "gpt-5",
|
||||
CreatedAt: time.Date(2025, 1, 4, 12, 0, 0, 0, time.UTC),
|
||||
})
|
||||
|
||||
mock.ExpectExec("INSERT INTO usage_logs").
|
||||
WithArgs(anySliceToDriverValues(prepared.args)...).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err := execUsageLogInsertNoResult(context.Background(), db, prepared)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) {
|
||||
prepared := prepareUsageLogInsert(&service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-arg-count",
|
||||
Model: "gpt-5",
|
||||
RequestedModel: "gpt-5",
|
||||
CreatedAt: time.Date(2025, 1, 5, 12, 0, 0, 0, time.UTC),
|
||||
})
|
||||
|
||||
require.Len(t, prepared.args, len(usageLogInsertArgTypes))
|
||||
}
|
||||
|
||||
func TestCoalesceTrimmedString(t *testing.T) {
|
||||
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback"))
|
||||
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback"))
|
||||
require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback"))
|
||||
}
|
||||
|
||||
func anySliceToDriverValues(values []any) []driver.Value {
|
||||
out := make([]driver.Value, 0, len(values))
|
||||
for _, value := range values {
|
||||
out = append(out, value)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
@@ -354,7 +428,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
int64(20), // api_key_id
|
||||
int64(30), // account_id
|
||||
sql.NullString{Valid: true, String: "req-1"},
|
||||
"gpt-5", // model
|
||||
"gpt-5", // model
|
||||
sql.NullString{Valid: true, String: "gpt-5"}, // requested_model
|
||||
sql.NullString{}, // upstream_model
|
||||
sql.NullInt64{}, // group_id
|
||||
sql.NullInt64{}, // subscription_id
|
||||
@@ -407,6 +482,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
int64(31),
|
||||
sql.NullString{Valid: true, String: "req-2"},
|
||||
"gpt-5",
|
||||
sql.NullString{Valid: true, String: "gpt-5"},
|
||||
sql.NullString{},
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
@@ -449,6 +525,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
int64(32),
|
||||
sql.NullString{Valid: true, String: "req-3"},
|
||||
"gpt-5.4",
|
||||
sql.NullString{Valid: true, String: "gpt-5.4"},
|
||||
sql.NullString{},
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
|
||||
@@ -540,7 +540,8 @@ func TestAPIContracts(t *testing.T) {
|
||||
"max_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"custom_menu_items": []
|
||||
"custom_menu_items": [],
|
||||
"custom_endpoints": []
|
||||
}
|
||||
}`,
|
||||
},
|
||||
|
||||
@@ -1543,6 +1543,24 @@ func isPeriodExpired(periodStart time.Time, dur time.Duration) bool {
|
||||
return time.Since(periodStart) >= dur
|
||||
}
|
||||
|
||||
// IsDailyQuotaPeriodExpired 检查日配额周期是否已过期(用于显示层判断是否需要将 used 归零)
|
||||
func (a *Account) IsDailyQuotaPeriodExpired() bool {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if a.GetQuotaDailyResetMode() == "fixed" {
|
||||
return a.isFixedDailyPeriodExpired(start)
|
||||
}
|
||||
return isPeriodExpired(start, 24*time.Hour)
|
||||
}
|
||||
|
||||
// IsWeeklyQuotaPeriodExpired 检查周配额周期是否已过期(用于显示层判断是否需要将 used 归零)
|
||||
func (a *Account) IsWeeklyQuotaPeriodExpired() bool {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
return a.isFixedWeeklyPeriodExpired(start)
|
||||
}
|
||||
return isPeriodExpired(start, 7*24*time.Hour)
|
||||
}
|
||||
|
||||
// IsQuotaExceeded 检查 API Key 账号配额是否已超限(任一维度超限即返回 true)
|
||||
func (a *Account) IsQuotaExceeded() bool {
|
||||
// 总额度
|
||||
|
||||
@@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: billingModel, // 使用映射模型用于计费和日志
|
||||
Model: originalModel,
|
||||
UpstreamModel: billingModel,
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -2435,7 +2436,8 @@ handleSuccess:
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: billingModel,
|
||||
Model: originalModel,
|
||||
UpstreamModel: billingModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
|
||||
@@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
|
||||
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
require.Equal(t, "claude-sonnet-4-5", result.Model)
|
||||
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
|
||||
@@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
require.Equal(t, "gemini-2.5-flash", result.Model)
|
||||
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||
@@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
require.Equal(t, originalModel, result.Model)
|
||||
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
|
||||
|
||||
firstReq := string(upstream.requestBodies[0])
|
||||
|
||||
@@ -222,10 +222,10 @@ func (s *BillingService) initFallbackPricing() {
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{
|
||||
InputPricePerToken: 7.5e-7,
|
||||
OutputPricePerToken: 4.5e-6,
|
||||
CacheReadPricePerToken: 7.5e-8,
|
||||
SupportsCacheBreakdown: false,
|
||||
InputPricePerToken: 7.5e-7,
|
||||
OutputPricePerToken: 4.5e-6,
|
||||
CacheReadPricePerToken: 7.5e-8,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-7,
|
||||
|
||||
@@ -119,6 +119,7 @@ const (
|
||||
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
|
||||
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src)
|
||||
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组)
|
||||
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
|
||||
@@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID
|
||||
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
mappedModel := "claude-sonnet-4-20250514"
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_models_split",
|
||||
Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6},
|
||||
Model: "claude-sonnet-4",
|
||||
UpstreamModel: mappedModel,
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
|
||||
@@ -482,10 +482,12 @@ type ClaudeUsage struct {
|
||||
|
||||
// ForwardResult 转发结果
|
||||
type ForwardResult struct {
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
// UpstreamModel is the actual upstream model after mapping.
|
||||
// Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings.
|
||||
UpstreamModel string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
@@ -4197,7 +4199,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "[warn] Account %d: thinking blocks have invalid signature, retrying with filtered blocks", account.ID)
|
||||
|
||||
// Conservative two-stage fallback:
|
||||
// 1) Disable thinking + thinking->text (preserve content)
|
||||
@@ -4212,7 +4214,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
if retryResp.StatusCode < 400 {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID)
|
||||
resp = retryResp
|
||||
break
|
||||
}
|
||||
@@ -6102,13 +6104,9 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Log for debugging
|
||||
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg)
|
||||
|
||||
// 检测signature相关的错误(更宽松的匹配)
|
||||
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
|
||||
if strings.Contains(msg, "signature") {
|
||||
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error")
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -7516,6 +7514,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.MediaType == "image" || result.MediaType == "video" {
|
||||
@@ -7531,7 +7530,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if result.MediaType == "image" {
|
||||
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||
} else {
|
||||
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
|
||||
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||
}
|
||||
} else if result.MediaType == "prompt" {
|
||||
cost = &CostBreakdown{}
|
||||
@@ -7545,7 +7544,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费
|
||||
tokens := UsageTokens{
|
||||
@@ -7557,7 +7556,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
@@ -7589,6 +7588,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
@@ -7719,6 +7719,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.ImageCount > 0 {
|
||||
@@ -7731,7 +7732,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费(使用长上下文计费方法)
|
||||
tokens := UsageTokens{
|
||||
@@ -7743,7 +7744,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
@@ -7771,6 +7772,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
|
||||
@@ -1028,14 +1028,15 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: req.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
Stream: req.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1241,12 +1242,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
@@ -1310,12 +1312,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
// Final attempt: surface the upstream error body (passed through below) instead of a generic retry error.
|
||||
@@ -1350,12 +1353,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
RequestID: requestID,
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1527,14 +1531,15 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,6 +16,30 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type geminiCompatHTTPUpstreamStub struct {
|
||||
response *http.Response
|
||||
err error
|
||||
calls int
|
||||
lastReq *http.Request
|
||||
}
|
||||
|
||||
func (s *geminiCompatHTTPUpstreamStub) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
s.calls++
|
||||
s.lastReq = req
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.response == nil {
|
||||
return nil, fmt.Errorf("missing stub response")
|
||||
}
|
||||
resp := *s.response
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -170,6 +195,42 @@ func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLo
|
||||
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpstreamModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
httpStub := &geminiCompatHTTPUpstreamStub{
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"x-request-id": []string{"gemini-req-1"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)),
|
||||
},
|
||||
}
|
||||
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-key",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4": "claude-sonnet-4-20250514",
|
||||
},
|
||||
},
|
||||
}
|
||||
body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "claude-sonnet-4", result.Model)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", result.UpstreamModel)
|
||||
require.Equal(t, 1, httpStub.calls)
|
||||
require.NotNil(t, httpStub.lastReq)
|
||||
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
|
||||
}
|
||||
|
||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
|
||||
@@ -879,6 +879,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
|
||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||
@@ -894,6 +895,40 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_upstream_model_billing_fallback",
|
||||
Model: "gpt-5.1",
|
||||
UpstreamModel: "gpt-5.1-codex",
|
||||
Usage: usage,
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
|
||||
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
|
||||
require.Equal(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost)
|
||||
require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
|
||||
@@ -4110,9 +4110,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
|
||||
billingModel := result.Model
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
if result.BillingModel != "" {
|
||||
billingModel = result.BillingModel
|
||||
billingModel = strings.TrimSpace(result.BillingModel)
|
||||
}
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
@@ -4140,6 +4140,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
|
||||
@@ -68,3 +68,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "")
|
||||
if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" {
|
||||
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1")
|
||||
}
|
||||
|
||||
withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")
|
||||
if got := normalizeCodexModel(withDefault); got != "gpt-5.4" {
|
||||
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2328,6 +2328,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
RequestID: responseID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
ServiceTier: extractOpenAIServiceTier(reqBody),
|
||||
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
|
||||
Stream: reqStream,
|
||||
@@ -2945,6 +2946,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
RequestID: responseID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
UpstreamModel: mappedModel,
|
||||
ServiceTier: extractOpenAIServiceTierFromBody(payload),
|
||||
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
|
||||
Stream: reqStream,
|
||||
|
||||
@@ -88,6 +88,7 @@ func (s *OpsAlertEvaluatorService) Start() {
|
||||
if s.stopCh == nil {
|
||||
s.stopCh = make(chan struct{})
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
})
|
||||
}
|
||||
@@ -105,7 +106,6 @@ func (s *OpsAlertEvaluatorService) Stop() {
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) run() {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
// Start immediately to produce early feedback in ops dashboard.
|
||||
@@ -848,7 +848,9 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc
|
||||
return nil, false
|
||||
}
|
||||
return func() {
|
||||
_, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
releaseCtx, releaseCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer releaseCancel()
|
||||
_, _ = opsAlertEvaluatorReleaseScript.Run(releaseCtx, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
}, true
|
||||
}
|
||||
|
||||
|
||||
@@ -150,6 +150,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyPurchaseSubscriptionURL,
|
||||
SettingKeySoraClientEnabled,
|
||||
SettingKeyCustomMenuItems,
|
||||
SettingKeyCustomEndpoints,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
SettingKeyBackendModeEnabled,
|
||||
}
|
||||
@@ -195,6 +196,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
}, nil
|
||||
@@ -247,6 +249,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
@@ -272,6 +275,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: s.version,
|
||||
@@ -314,6 +318,18 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
|
||||
return result
|
||||
}
|
||||
|
||||
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
|
||||
func safeRawJSONArray(raw string) json.RawMessage {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return json.RawMessage("[]")
|
||||
}
|
||||
if json.Valid([]byte(raw)) {
|
||||
return json.RawMessage(raw)
|
||||
}
|
||||
return json.RawMessage("[]")
|
||||
}
|
||||
|
||||
// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url
|
||||
// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
|
||||
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
|
||||
@@ -454,6 +470,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
|
||||
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
|
||||
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
|
||||
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
@@ -740,6 +757,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyPurchaseSubscriptionURL: "",
|
||||
SettingKeySoraClientEnabled: "false",
|
||||
SettingKeyCustomMenuItems: "[]",
|
||||
SettingKeyCustomEndpoints: "[]",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
@@ -805,6 +823,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ type SystemSettings struct {
|
||||
PurchaseSubscriptionURL string
|
||||
SoraClientEnabled bool
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
CustomEndpoints string // JSON array of custom endpoints
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
@@ -104,6 +105,7 @@ type PublicSettings struct {
|
||||
PurchaseSubscriptionURL string
|
||||
SoraClientEnabled bool
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
CustomEndpoints string // JSON array of custom endpoints
|
||||
|
||||
LinuxDoOAuthEnabled bool
|
||||
BackendModeEnabled bool
|
||||
|
||||
@@ -148,10 +148,13 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
originalModel := reqModel
|
||||
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
var upstreamModel string
|
||||
if mappedModel != "" && mappedModel != reqModel {
|
||||
reqModel = mappedModel
|
||||
upstreamModel = mappedModel
|
||||
}
|
||||
|
||||
modelCfg, ok := GetSoraModelConfig(reqModel)
|
||||
@@ -213,13 +216,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
RequestID: "",
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -269,13 +273,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
RequestID: "",
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
}, nil
|
||||
}
|
||||
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||
@@ -419,16 +424,17 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: taskID,
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: mediaType,
|
||||
MediaURL: firstMediaURL(finalURLs),
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
RequestID: taskID,
|
||||
Model: originalModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: mediaType,
|
||||
MediaURL: firstMediaURL(finalURLs),
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -144,6 +144,11 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"prompt-enhance-short-10s": "prompt-enhance-short-15s",
|
||||
},
|
||||
},
|
||||
}
|
||||
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
|
||||
|
||||
@@ -152,6 +157,7 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "prompt", result.MediaType)
|
||||
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
||||
require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
||||
|
||||
@@ -98,6 +98,9 @@ type UsageLog struct {
|
||||
AccountID int64
|
||||
RequestID string
|
||||
Model string
|
||||
// RequestedModel is the client-requested model name recorded for stable user/admin display.
|
||||
// Empty should be treated as Model for backward compatibility with historical rows.
|
||||
RequestedModel string
|
||||
// UpstreamModel is the actual model sent to the upstream provider after mapping.
|
||||
// Nil means no mapping was applied (requested model was used as-is).
|
||||
UpstreamModel *string
|
||||
|
||||
@@ -19,3 +19,10 @@ func optionalNonEqualStringPtr(value, compare string) *string {
|
||||
}
|
||||
return &value
|
||||
}
|
||||
|
||||
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
|
||||
if trimmedUpstream := strings.TrimSpace(upstreamModel); trimmedUpstream != "" {
|
||||
return trimmedUpstream
|
||||
}
|
||||
return strings.TrimSpace(requestedModel)
|
||||
}
|
||||
|
||||
3
backend/migrations/077_add_usage_log_requested_model.sql
Normal file
3
backend/migrations/077_add_usage_log_requested_model.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
-- Add requested_model field to usage_logs for normalized request/upstream model tracking.
|
||||
-- NULL means historical rows written before requested_model dual-write was introduced.
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100);
|
||||
@@ -0,0 +1,3 @@
|
||||
-- Support requested_model / upstream_model aggregations with time-range filters.
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_requested_model_upstream_model
|
||||
ON usage_logs (created_at, requested_model, upstream_model);
|
||||
@@ -4,7 +4,7 @@
|
||||
*/
|
||||
|
||||
import { apiClient } from '../client'
|
||||
import type { CustomMenuItem } from '@/types'
|
||||
import type { CustomMenuItem, CustomEndpoint } from '@/types'
|
||||
|
||||
export interface DefaultSubscriptionSetting {
|
||||
group_id: number
|
||||
@@ -43,6 +43,7 @@ export interface SystemSettings {
|
||||
sora_client_enabled: boolean
|
||||
backend_mode_enabled: boolean
|
||||
custom_menu_items: CustomMenuItem[]
|
||||
custom_endpoints: CustomEndpoint[]
|
||||
// SMTP settings
|
||||
smtp_host: string
|
||||
smtp_port: number
|
||||
@@ -112,6 +113,7 @@ export interface UpdateSettingsRequest {
|
||||
sora_client_enabled?: boolean
|
||||
backend_mode_enabled?: boolean
|
||||
custom_menu_items?: CustomMenuItem[]
|
||||
custom_endpoints?: CustomEndpoint[]
|
||||
smtp_host?: string
|
||||
smtp_port?: number
|
||||
smtp_username?: string
|
||||
|
||||
@@ -39,6 +39,7 @@ const DataTableStub = {
|
||||
template: `
|
||||
<div>
|
||||
<div v-for="row in data" :key="row.request_id">
|
||||
<slot name="cell-model" :row="row" :value="row.model" />
|
||||
<slot name="cell-cost" :row="row" />
|
||||
</div>
|
||||
</div>
|
||||
@@ -108,4 +109,42 @@ describe('admin UsageTable tooltip', () => {
|
||||
expect(text).toContain('$30.0000 / 1M tokens')
|
||||
expect(text).toContain('$0.069568')
|
||||
})
|
||||
|
||||
it('shows requested and upstream models separately for admin rows', () => {
|
||||
const row = {
|
||||
request_id: 'req-admin-model-1',
|
||||
model: 'claude-sonnet-4',
|
||||
upstream_model: 'claude-sonnet-4-20250514',
|
||||
actual_cost: 0,
|
||||
total_cost: 0,
|
||||
account_rate_multiplier: 1,
|
||||
rate_multiplier: 1,
|
||||
input_cost: 0,
|
||||
output_cost: 0,
|
||||
cache_creation_cost: 0,
|
||||
cache_read_cost: 0,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
}
|
||||
|
||||
const wrapper = mount(UsageTable, {
|
||||
props: {
|
||||
data: [row],
|
||||
loading: false,
|
||||
columns: [],
|
||||
},
|
||||
global: {
|
||||
stubs: {
|
||||
DataTable: DataTableStub,
|
||||
EmptyState: true,
|
||||
Icon: true,
|
||||
Teleport: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const text = wrapper.text()
|
||||
expect(text).toContain('claude-sonnet-4')
|
||||
expect(text).toContain('claude-sonnet-4-20250514')
|
||||
})
|
||||
})
|
||||
|
||||
141
frontend/src/components/keys/EndpointPopover.vue
Normal file
141
frontend/src/components/keys/EndpointPopover.vue
Normal file
@@ -0,0 +1,141 @@
|
||||
<script setup lang="ts">
|
||||
import { computed, onBeforeUnmount, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import type { CustomEndpoint } from '@/types'
|
||||
|
||||
const props = defineProps<{
|
||||
apiBaseUrl: string
|
||||
customEndpoints: CustomEndpoint[]
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const { copyToClipboard } = useClipboard()
|
||||
const copiedEndpoint = ref<string | null>(null)
|
||||
|
||||
let copiedResetTimer: number | undefined
|
||||
|
||||
const allEndpoints = computed(() => {
|
||||
const items: Array<{ name: string; endpoint: string; description: string; isDefault: boolean }> = []
|
||||
if (props.apiBaseUrl) {
|
||||
items.push({
|
||||
name: t('keys.endpoints.title'),
|
||||
endpoint: props.apiBaseUrl,
|
||||
description: '',
|
||||
isDefault: true,
|
||||
})
|
||||
}
|
||||
for (const ep of props.customEndpoints) {
|
||||
items.push({ ...ep, isDefault: false })
|
||||
}
|
||||
return items
|
||||
})
|
||||
|
||||
async function copy(url: string) {
|
||||
const success = await copyToClipboard(url, t('keys.endpoints.copied'))
|
||||
if (!success) return
|
||||
|
||||
copiedEndpoint.value = url
|
||||
if (copiedResetTimer !== undefined) {
|
||||
window.clearTimeout(copiedResetTimer)
|
||||
}
|
||||
copiedResetTimer = window.setTimeout(() => {
|
||||
if (copiedEndpoint.value === url) {
|
||||
copiedEndpoint.value = null
|
||||
}
|
||||
}, 1800)
|
||||
}
|
||||
|
||||
function tooltipHint(endpoint: string): string {
|
||||
return copiedEndpoint.value === endpoint
|
||||
? t('keys.endpoints.copiedHint')
|
||||
: t('keys.endpoints.clickToCopy')
|
||||
}
|
||||
|
||||
function speedTestUrl(endpoint: string): string {
|
||||
return `https://www.tcptest.cn/http/${encodeURIComponent(endpoint)}`
|
||||
}
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
if (copiedResetTimer !== undefined) {
|
||||
window.clearTimeout(copiedResetTimer)
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div v-if="allEndpoints.length > 0" class="flex flex-wrap gap-2">
|
||||
<div
|
||||
v-for="(item, index) in allEndpoints"
|
||||
:key="index"
|
||||
class="flex items-center gap-1.5 rounded-lg border border-gray-200 bg-white px-2.5 py-1.5 text-xs transition-colors hover:border-primary-200 dark:border-dark-600 dark:bg-dark-800 dark:hover:border-primary-700"
|
||||
>
|
||||
<span class="font-medium text-gray-600 dark:text-gray-300">{{ item.name }}</span>
|
||||
<span
|
||||
v-if="item.isDefault"
|
||||
class="rounded bg-primary-50 px-1 py-px text-[10px] font-medium leading-tight text-primary-600 dark:bg-primary-900/30 dark:text-primary-400"
|
||||
>{{ t('keys.endpoints.default') }}</span>
|
||||
|
||||
<span class="text-gray-300 dark:text-dark-500">|</span>
|
||||
|
||||
<div class="group/endpoint relative flex items-center gap-1.5">
|
||||
<div
|
||||
class="pointer-events-none absolute bottom-full left-1/2 z-20 mb-2 w-max max-w-[24rem] -translate-x-1/2 translate-y-1 rounded-xl border border-slate-200 bg-white px-3 py-2.5 text-left opacity-0 shadow-[0_14px_36px_-20px_rgba(15,23,42,0.35)] ring-1 ring-slate-200/80 transition-all duration-150 group-hover/endpoint:translate-y-0 group-hover/endpoint:opacity-100 group-focus-within/endpoint:translate-y-0 group-focus-within/endpoint:opacity-100 dark:border-slate-700 dark:bg-slate-900 dark:ring-slate-700/70"
|
||||
>
|
||||
<p
|
||||
v-if="item.description"
|
||||
class="max-w-[24rem] break-words text-xs leading-5 text-slate-600 dark:text-slate-200"
|
||||
>
|
||||
{{ item.description }}
|
||||
</p>
|
||||
<p
|
||||
class="flex items-center gap-1.5 text-[11px] leading-4 text-primary-600 dark:text-primary-300"
|
||||
:class="item.description ? 'mt-1.5' : ''"
|
||||
>
|
||||
<span class="h-1.5 w-1.5 rounded-full bg-primary-500 dark:bg-primary-300"></span>
|
||||
{{ tooltipHint(item.endpoint) }}
|
||||
</p>
|
||||
<div class="absolute left-1/2 top-full h-3 w-3 -translate-x-1/2 -translate-y-1/2 rotate-45 border-b border-r border-slate-200 bg-white dark:border-slate-700 dark:bg-slate-900"></div>
|
||||
</div>
|
||||
|
||||
<code
|
||||
class="cursor-pointer font-mono text-gray-500 decoration-gray-400 decoration-dashed underline-offset-2 hover:text-primary-600 hover:underline focus:text-primary-600 focus:underline focus:outline-none dark:text-gray-400 dark:decoration-gray-500 dark:hover:text-primary-400 dark:focus:text-primary-400"
|
||||
role="button"
|
||||
tabindex="0"
|
||||
@click="copy(item.endpoint)"
|
||||
@keydown.enter.prevent="copy(item.endpoint)"
|
||||
@keydown.space.prevent="copy(item.endpoint)"
|
||||
>{{ item.endpoint }}</code>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="rounded p-0.5 transition-colors"
|
||||
:class="copiedEndpoint === item.endpoint
|
||||
? 'text-emerald-500 dark:text-emerald-400'
|
||||
: 'text-gray-400 hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400'"
|
||||
:aria-label="tooltipHint(item.endpoint)"
|
||||
@click="copy(item.endpoint)"
|
||||
>
|
||||
<svg v-if="copiedEndpoint === item.endpoint" class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2.2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M5 13l4 4L19 7" />
|
||||
</svg>
|
||||
<svg v-else class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z" />
|
||||
</svg>
|
||||
</button>
|
||||
|
||||
<a
|
||||
:href="speedTestUrl(item.endpoint)"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="rounded p-0.5 text-gray-400 transition-colors hover:text-amber-500 dark:text-gray-500 dark:hover:text-amber-400"
|
||||
:title="t('keys.endpoints.speedTest')"
|
||||
>
|
||||
<svg class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M13 10V3L4 14h7v7l9-11h-7z" />
|
||||
</svg>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
@@ -0,0 +1,69 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { flushPromises, mount } from '@vue/test-utils'
|
||||
|
||||
const copyToClipboard = vi.fn().mockResolvedValue(true)
|
||||
|
||||
const messages: Record<string, string> = {
|
||||
'keys.endpoints.title': 'API 端点',
|
||||
'keys.endpoints.default': '默认',
|
||||
'keys.endpoints.copied': '已复制',
|
||||
'keys.endpoints.copiedHint': '已复制到剪贴板',
|
||||
'keys.endpoints.clickToCopy': '点击可复制此端点',
|
||||
'keys.endpoints.speedTest': '测速',
|
||||
}
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({
|
||||
t: (key: string) => messages[key] ?? key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useClipboard', () => ({
|
||||
useClipboard: () => ({
|
||||
copyToClipboard,
|
||||
}),
|
||||
}))
|
||||
|
||||
import EndpointPopover from '../EndpointPopover.vue'
|
||||
|
||||
describe('EndpointPopover', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('将说明提示渲染到 URL 上方而不是旧的 title 图标上', () => {
|
||||
const wrapper = mount(EndpointPopover, {
|
||||
props: {
|
||||
apiBaseUrl: 'https://default.example.com/v1',
|
||||
customEndpoints: [
|
||||
{
|
||||
name: '备用线路',
|
||||
endpoint: 'https://backup.example.com/v1',
|
||||
description: '自定义说明',
|
||||
},
|
||||
],
|
||||
},
|
||||
})
|
||||
|
||||
expect(wrapper.text()).toContain('自定义说明')
|
||||
expect(wrapper.text()).toContain('点击可复制此端点')
|
||||
expect(wrapper.find('[role="button"]').attributes('title')).toBeUndefined()
|
||||
expect(wrapper.find('[title="自定义说明"]').exists()).toBe(false)
|
||||
})
|
||||
|
||||
it('点击 URL 后会复制并切换为已复制提示', async () => {
|
||||
const wrapper = mount(EndpointPopover, {
|
||||
props: {
|
||||
apiBaseUrl: 'https://default.example.com/v1',
|
||||
customEndpoints: [],
|
||||
},
|
||||
})
|
||||
|
||||
await wrapper.find('[role="button"]').trigger('click')
|
||||
await flushPromises()
|
||||
|
||||
expect(copyToClipboard).toHaveBeenCalledWith('https://default.example.com/v1', '已复制')
|
||||
expect(wrapper.text()).toContain('已复制到剪贴板')
|
||||
expect(wrapper.find('button[aria-label="已复制到剪贴板"]').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
@@ -533,6 +533,14 @@ export default {
|
||||
title: 'API Keys',
|
||||
description: 'Manage your API keys and access tokens',
|
||||
searchPlaceholder: 'Search name or key...',
|
||||
endpoints: {
|
||||
title: 'API Endpoints',
|
||||
default: 'Default',
|
||||
copied: 'Copied',
|
||||
copiedHint: 'Copied to clipboard',
|
||||
clickToCopy: 'Click to copy this endpoint',
|
||||
speedTest: 'Speed Test',
|
||||
},
|
||||
allGroups: 'All Groups',
|
||||
allStatus: 'All Status',
|
||||
createKey: 'Create API Key',
|
||||
@@ -4162,6 +4170,18 @@ export default {
|
||||
apiBaseUrlPlaceholder: 'https://api.example.com',
|
||||
apiBaseUrlHint:
|
||||
'Used for "Use Key" and "Import to CC Switch" features. Leave empty to use current site URL.',
|
||||
customEndpoints: {
|
||||
title: 'Custom Endpoints',
|
||||
description: 'Add additional API endpoint URLs for users to quickly copy on the API Keys page',
|
||||
itemLabel: 'Endpoint #{n}',
|
||||
name: 'Name',
|
||||
namePlaceholder: 'e.g., OpenAI Compatible',
|
||||
endpointUrl: 'Endpoint URL',
|
||||
endpointUrlPlaceholder: 'https://api2.example.com',
|
||||
descriptionLabel: 'Description',
|
||||
descriptionPlaceholder: 'e.g., Supports OpenAI format requests',
|
||||
add: 'Add Endpoint',
|
||||
},
|
||||
contactInfo: 'Contact Info',
|
||||
contactInfoPlaceholder: 'e.g., QQ: 123456789',
|
||||
contactInfoHint: 'Customer support contact info, displayed on redeem page, profile, etc.',
|
||||
|
||||
@@ -533,6 +533,14 @@ export default {
|
||||
title: 'API 密钥',
|
||||
description: '管理您的 API 密钥和访问令牌',
|
||||
searchPlaceholder: '搜索名称或Key...',
|
||||
endpoints: {
|
||||
title: 'API 端点',
|
||||
default: '默认',
|
||||
copied: '已复制',
|
||||
copiedHint: '已复制到剪贴板',
|
||||
clickToCopy: '点击可复制此端点',
|
||||
speedTest: '测速',
|
||||
},
|
||||
allGroups: '全部分组',
|
||||
allStatus: '全部状态',
|
||||
createKey: '创建密钥',
|
||||
@@ -4324,6 +4332,18 @@ export default {
|
||||
apiBaseUrl: 'API 端点地址',
|
||||
apiBaseUrlHint: '用于"使用密钥"和"导入到 CC Switch"功能,留空则使用当前站点地址',
|
||||
apiBaseUrlPlaceholder: 'https://api.example.com',
|
||||
customEndpoints: {
|
||||
title: '自定义端点',
|
||||
description: '添加额外的 API 端点地址,用户可在「API Keys」页面快速复制',
|
||||
itemLabel: '端点 #{n}',
|
||||
name: '名称',
|
||||
namePlaceholder: '如:OpenAI Compatible',
|
||||
endpointUrl: '端点地址',
|
||||
endpointUrlPlaceholder: 'https://api2.example.com',
|
||||
descriptionLabel: '介绍',
|
||||
descriptionPlaceholder: '如:支持 OpenAI 格式请求',
|
||||
add: '添加端点',
|
||||
},
|
||||
contactInfo: '客服联系方式',
|
||||
contactInfoPlaceholder: '例如:QQ: 123456789',
|
||||
contactInfoHint: '填写客服联系方式,将展示在兑换页面、个人资料等位置',
|
||||
|
||||
@@ -330,6 +330,7 @@ export const useAppStore = defineStore('app', () => {
|
||||
purchase_subscription_enabled: false,
|
||||
purchase_subscription_url: '',
|
||||
custom_menu_items: [],
|
||||
custom_endpoints: [],
|
||||
linuxdo_oauth_enabled: false,
|
||||
sora_client_enabled: false,
|
||||
backend_mode_enabled: false,
|
||||
|
||||
@@ -84,6 +84,12 @@ export interface CustomMenuItem {
|
||||
sort_order: number
|
||||
}
|
||||
|
||||
export interface CustomEndpoint {
|
||||
name: string
|
||||
endpoint: string
|
||||
description: string
|
||||
}
|
||||
|
||||
export interface PublicSettings {
|
||||
registration_enabled: boolean
|
||||
email_verify_enabled: boolean
|
||||
@@ -104,6 +110,7 @@ export interface PublicSettings {
|
||||
purchase_subscription_enabled: boolean
|
||||
purchase_subscription_url: string
|
||||
custom_menu_items: CustomMenuItem[]
|
||||
custom_endpoints: CustomEndpoint[]
|
||||
linuxdo_oauth_enabled: boolean
|
||||
sora_client_enabled: boolean
|
||||
backend_mode_enabled: boolean
|
||||
@@ -978,7 +985,6 @@ export interface UsageLog {
|
||||
account_id: number | null
|
||||
request_id: string
|
||||
model: string
|
||||
upstream_model?: string | null
|
||||
service_tier?: string | null
|
||||
reasoning_effort?: string | null
|
||||
inbound_endpoint?: string | null
|
||||
@@ -1033,6 +1039,8 @@ export interface UsageLogAccountSummary {
|
||||
}
|
||||
|
||||
export interface AdminUsageLog extends UsageLog {
|
||||
upstream_model?: string | null
|
||||
|
||||
// 账号计费倍率(仅管理员可见)
|
||||
account_rate_multiplier?: number | null
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Settings Form -->
|
||||
<form v-else @submit.prevent="saveSettings" class="space-y-6">
|
||||
<form v-else @submit.prevent="saveSettings" class="space-y-6" novalidate>
|
||||
<!-- Tab Navigation -->
|
||||
<div class="sticky top-0 z-10 overflow-x-auto settings-tabs-scroll">
|
||||
<nav class="settings-tabs">
|
||||
@@ -1248,6 +1248,81 @@
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Custom Endpoints -->
|
||||
<div>
|
||||
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.settings.site.customEndpoints.title') }}
|
||||
</label>
|
||||
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.settings.site.customEndpoints.description') }}
|
||||
</p>
|
||||
|
||||
<div class="space-y-3">
|
||||
<div
|
||||
v-for="(ep, index) in form.custom_endpoints"
|
||||
:key="index"
|
||||
class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
|
||||
>
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<span class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.settings.site.customEndpoints.itemLabel', { n: index + 1 }) }}
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
class="rounded p-1 text-red-400 hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
|
||||
@click="removeEndpoint(index)"
|
||||
>
|
||||
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" /></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-3 sm:grid-cols-2">
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.settings.site.customEndpoints.name') }}
|
||||
</label>
|
||||
<input
|
||||
v-model="ep.name"
|
||||
type="text"
|
||||
class="input text-sm"
|
||||
:placeholder="t('admin.settings.site.customEndpoints.namePlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.settings.site.customEndpoints.endpointUrl') }}
|
||||
</label>
|
||||
<input
|
||||
v-model="ep.endpoint"
|
||||
type="url"
|
||||
class="input font-mono text-sm"
|
||||
:placeholder="t('admin.settings.site.customEndpoints.endpointUrlPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<div class="sm:col-span-2">
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">
|
||||
{{ t('admin.settings.site.customEndpoints.descriptionLabel') }}
|
||||
</label>
|
||||
<input
|
||||
v-model="ep.description"
|
||||
type="text"
|
||||
class="input text-sm"
|
||||
:placeholder="t('admin.settings.site.customEndpoints.descriptionPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="mt-3 flex w-full items-center justify-center gap-2 rounded-lg border-2 border-dashed border-gray-300 px-4 py-2.5 text-sm text-gray-500 transition-colors hover:border-primary-400 hover:text-primary-600 dark:border-dark-600 dark:text-gray-400 dark:hover:border-primary-500 dark:hover:text-primary-400"
|
||||
@click="addEndpoint"
|
||||
>
|
||||
<svg class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M12 4v16m8-8H4" /></svg>
|
||||
{{ t('admin.settings.site.customEndpoints.add') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Contact Info -->
|
||||
<div>
|
||||
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
@@ -1945,6 +2020,7 @@ const form = reactive<SettingsForm>({
|
||||
purchase_subscription_url: '',
|
||||
sora_client_enabled: false,
|
||||
custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>,
|
||||
custom_endpoints: [] as Array<{name: string; endpoint: string; description: string}>,
|
||||
frontend_url: '',
|
||||
smtp_host: '',
|
||||
smtp_port: 587,
|
||||
@@ -2114,6 +2190,15 @@ function moveMenuItem(index: number, direction: -1 | 1) {
|
||||
})
|
||||
}
|
||||
|
||||
// Custom endpoint management
|
||||
function addEndpoint() {
|
||||
form.custom_endpoints.push({ name: '', endpoint: '', description: '' })
|
||||
}
|
||||
|
||||
function removeEndpoint(index: number) {
|
||||
form.custom_endpoints.splice(index, 1)
|
||||
}
|
||||
|
||||
async function loadSettings() {
|
||||
loading.value = true
|
||||
try {
|
||||
@@ -2198,6 +2283,35 @@ async function saveSettings() {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate URL fields — novalidate disables browser-native checks, so we validate here
|
||||
const isValidHttpUrl = (url: string): boolean => {
|
||||
if (!url) return true
|
||||
try {
|
||||
const u = new URL(url)
|
||||
return u.protocol === 'http:' || u.protocol === 'https:'
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// Optional URL fields: auto-clear invalid values so they don't cause backend 400 errors
|
||||
if (!isValidHttpUrl(form.frontend_url)) form.frontend_url = ''
|
||||
if (!isValidHttpUrl(form.doc_url)) form.doc_url = ''
|
||||
// Purchase URL: required when enabled; auto-clear when disabled to avoid backend rejection
|
||||
if (form.purchase_subscription_enabled) {
|
||||
if (!form.purchase_subscription_url) {
|
||||
appStore.showError(t('admin.settings.purchase.url') + ': URL is required when purchase is enabled')
|
||||
saving.value = false
|
||||
return
|
||||
}
|
||||
if (!isValidHttpUrl(form.purchase_subscription_url)) {
|
||||
appStore.showError(t('admin.settings.purchase.url') + ': must be an absolute http(s) URL (e.g. https://example.com)')
|
||||
saving.value = false
|
||||
return
|
||||
}
|
||||
} else if (!isValidHttpUrl(form.purchase_subscription_url)) {
|
||||
form.purchase_subscription_url = ''
|
||||
}
|
||||
|
||||
const payload: UpdateSettingsRequest = {
|
||||
registration_enabled: form.registration_enabled,
|
||||
email_verify_enabled: form.email_verify_enabled,
|
||||
@@ -2224,6 +2338,7 @@ async function saveSettings() {
|
||||
purchase_subscription_url: form.purchase_subscription_url,
|
||||
sora_client_enabled: form.sora_client_enabled,
|
||||
custom_menu_items: form.custom_menu_items,
|
||||
custom_endpoints: form.custom_endpoints,
|
||||
frontend_url: form.frontend_url,
|
||||
smtp_host: form.smtp_host,
|
||||
smtp_port: form.smtp_port,
|
||||
|
||||
@@ -2,24 +2,31 @@
|
||||
<AppLayout>
|
||||
<TablePageLayout>
|
||||
<template #filters>
|
||||
<div class="flex flex-wrap items-center gap-3">
|
||||
<SearchInput
|
||||
v-model="filterSearch"
|
||||
:placeholder="t('keys.searchPlaceholder')"
|
||||
class="w-full sm:w-64"
|
||||
@search="onFilterChange"
|
||||
/>
|
||||
<Select
|
||||
:model-value="filterGroupId"
|
||||
class="w-40"
|
||||
:options="groupFilterOptions"
|
||||
@update:model-value="onGroupFilterChange"
|
||||
/>
|
||||
<Select
|
||||
:model-value="filterStatus"
|
||||
class="w-40"
|
||||
:options="statusFilterOptions"
|
||||
@update:model-value="onStatusFilterChange"
|
||||
<div class="flex flex-col gap-3">
|
||||
<div class="flex flex-wrap items-center gap-3">
|
||||
<SearchInput
|
||||
v-model="filterSearch"
|
||||
:placeholder="t('keys.searchPlaceholder')"
|
||||
class="w-full sm:w-64"
|
||||
@search="onFilterChange"
|
||||
/>
|
||||
<Select
|
||||
:model-value="filterGroupId"
|
||||
class="w-40"
|
||||
:options="groupFilterOptions"
|
||||
@update:model-value="onGroupFilterChange"
|
||||
/>
|
||||
<Select
|
||||
:model-value="filterStatus"
|
||||
class="w-40"
|
||||
:options="statusFilterOptions"
|
||||
@update:model-value="onStatusFilterChange"
|
||||
/>
|
||||
</div>
|
||||
<EndpointPopover
|
||||
v-if="publicSettings?.api_base_url || (publicSettings?.custom_endpoints?.length ?? 0) > 0"
|
||||
:api-base-url="publicSettings?.api_base_url || ''"
|
||||
:custom-endpoints="publicSettings?.custom_endpoints || []"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
@@ -1050,6 +1057,7 @@ import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
||||
import SearchInput from '@/components/common/SearchInput.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import UseKeyModal from '@/components/keys/UseKeyModal.vue'
|
||||
import EndpointPopover from '@/components/keys/EndpointPopover.vue'
|
||||
import GroupBadge from '@/components/common/GroupBadge.vue'
|
||||
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
|
||||
import type { ApiKey, Group, PublicSettings, SubscriptionType, GroupPlatform } from '@/types'
|
||||
|
||||
Reference in New Issue
Block a user