diff --git a/backend/ent/account.go b/backend/ent/account.go index c77002b3..2dbfc3a2 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -41,6 +41,8 @@ type Account struct { ProxyID *int64 `json:"proxy_id,omitempty"` // Concurrency holds the value of the "concurrency" field. Concurrency int `json:"concurrency,omitempty"` + // LoadFactor holds the value of the "load_factor" field. + LoadFactor *int `json:"load_factor,omitempty"` // Priority holds the value of the "priority" field. Priority int `json:"priority,omitempty"` // RateMultiplier holds the value of the "rate_multiplier" field. @@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case account.FieldRateMultiplier: values[i] = new(sql.NullFloat64) - case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: + case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority: values[i] = new(sql.NullInt64) case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus: values[i] = new(sql.NullString) @@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Concurrency = int(value.Int64) } + case account.FieldLoadFactor: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field load_factor", values[i]) + } else if value.Valid { + _m.LoadFactor = new(int) + *_m.LoadFactor = int(value.Int64) + } case account.FieldPriority: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field priority", values[i]) @@ -445,6 +454,11 @@ func (_m *Account) String() string { builder.WriteString("concurrency=") builder.WriteString(fmt.Sprintf("%v", _m.Concurrency)) builder.WriteString(", ") + if v := _m.LoadFactor; v != nil { + builder.WriteString("load_factor=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("priority=") builder.WriteString(fmt.Sprintf("%v", _m.Priority)) builder.WriteString(", ") diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 1fc34620..4c134649 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -37,6 +37,8 @@ const ( FieldProxyID = "proxy_id" // FieldConcurrency holds the string denoting the concurrency field in the database. FieldConcurrency = "concurrency" + // FieldLoadFactor holds the string denoting the load_factor field in the database. + FieldLoadFactor = "load_factor" // FieldPriority holds the string denoting the priority field in the database. FieldPriority = "priority" // FieldRateMultiplier holds the string denoting the rate_multiplier field in the database. @@ -121,6 +123,7 @@ var Columns = []string{ FieldExtra, FieldProxyID, FieldConcurrency, + FieldLoadFactor, FieldPriority, FieldRateMultiplier, FieldStatus, @@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldConcurrency, opts...).ToFunc() } +// ByLoadFactor orders the results by the load_factor field. +func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLoadFactor, opts...).ToFunc() +} + // ByPriority orders the results by the priority field. func ByPriority(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldPriority, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index 54db1dcb..3749b45c 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldConcurrency, v)) } +// LoadFactor applies equality check predicate on the "load_factor" field. It's identical to LoadFactorEQ. +func LoadFactor(v int) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldLoadFactor, v)) +} + // Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ. func Priority(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldPriority, v)) @@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account { return predicate.Account(sql.FieldLTE(FieldConcurrency, v)) } +// LoadFactorEQ applies the EQ predicate on the "load_factor" field. +func LoadFactorEQ(v int) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldLoadFactor, v)) +} + +// LoadFactorNEQ applies the NEQ predicate on the "load_factor" field. +func LoadFactorNEQ(v int) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldLoadFactor, v)) +} + +// LoadFactorIn applies the In predicate on the "load_factor" field. +func LoadFactorIn(vs ...int) predicate.Account { + return predicate.Account(sql.FieldIn(FieldLoadFactor, vs...)) +} + +// LoadFactorNotIn applies the NotIn predicate on the "load_factor" field. +func LoadFactorNotIn(vs ...int) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldLoadFactor, vs...)) +} + +// LoadFactorGT applies the GT predicate on the "load_factor" field. +func LoadFactorGT(v int) predicate.Account { + return predicate.Account(sql.FieldGT(FieldLoadFactor, v)) +} + +// LoadFactorGTE applies the GTE predicate on the "load_factor" field. +func LoadFactorGTE(v int) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldLoadFactor, v)) +} + +// LoadFactorLT applies the LT predicate on the "load_factor" field. +func LoadFactorLT(v int) predicate.Account { + return predicate.Account(sql.FieldLT(FieldLoadFactor, v)) +} + +// LoadFactorLTE applies the LTE predicate on the "load_factor" field. +func LoadFactorLTE(v int) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldLoadFactor, v)) +} + +// LoadFactorIsNil applies the IsNil predicate on the "load_factor" field. +func LoadFactorIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldLoadFactor)) +} + +// LoadFactorNotNil applies the NotNil predicate on the "load_factor" field. +func LoadFactorNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldLoadFactor)) +} + // PriorityEQ applies the EQ predicate on the "priority" field. func PriorityEQ(v int) predicate.Account { return predicate.Account(sql.FieldEQ(FieldPriority, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 963ffee8..d6046c79 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate { return _c } +// SetLoadFactor sets the "load_factor" field. +func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate { + _c.mutation.SetLoadFactor(v) + return _c +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate { + if v != nil { + _c.SetLoadFactor(*v) + } + return _c +} + // SetPriority sets the "priority" field. func (_c *AccountCreate) SetPriority(v int) *AccountCreate { _c.mutation.SetPriority(v) @@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) _node.Concurrency = value } + if value, ok := _c.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + _node.LoadFactor = &value + } if value, ok := _c.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) _node.Priority = value @@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert { return u } +// SetLoadFactor sets the "load_factor" field. +func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert { + u.Set(account.FieldLoadFactor, v) + return u +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. +func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert { + u.SetExcluded(account.FieldLoadFactor) + return u +} + +// AddLoadFactor adds v to the "load_factor" field. +func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert { + u.Add(account.FieldLoadFactor, v) + return u +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert { + u.SetNull(account.FieldLoadFactor) + return u +} + // SetPriority sets the "priority" field. func (u *AccountUpsert) SetPriority(v int) *AccountUpsert { u.Set(account.FieldPriority, v) @@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne { }) } +// SetLoadFactor sets the "load_factor" field. +func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetLoadFactor(v) + }) +} + +// AddLoadFactor adds v to the "load_factor" field. +func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.AddLoadFactor(v) + }) +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateLoadFactor() + }) +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearLoadFactor() + }) +} + // SetPriority sets the "priority" field. func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk { }) } +// SetLoadFactor sets the "load_factor" field. +func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetLoadFactor(v) + }) +} + +// AddLoadFactor adds v to the "load_factor" field. +func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.AddLoadFactor(v) + }) +} + +// UpdateLoadFactor sets the "load_factor" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateLoadFactor() + }) +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearLoadFactor() + }) +} + // SetPriority sets the "priority" field. func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 875888e0..6f443c65 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate { return _u } +// SetLoadFactor sets the "load_factor" field. +func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate { + _u.mutation.ResetLoadFactor() + _u.mutation.SetLoadFactor(v) + return _u +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate { + if v != nil { + _u.SetLoadFactor(*v) + } + return _u +} + +// AddLoadFactor adds value to the "load_factor" field. +func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate { + _u.mutation.AddLoadFactor(v) + return _u +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate { + _u.mutation.ClearLoadFactor() + return _u +} + // SetPriority sets the "priority" field. func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate { _u.mutation.ResetPriority() @@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedConcurrency(); ok { _spec.AddField(account.FieldConcurrency, field.TypeInt, value) } + if value, ok := _u.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedLoadFactor(); ok { + _spec.AddField(account.FieldLoadFactor, field.TypeInt, value) + } + if _u.mutation.LoadFactorCleared() { + _spec.ClearField(account.FieldLoadFactor, field.TypeInt) + } if value, ok := _u.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) } @@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne { return _u } +// SetLoadFactor sets the "load_factor" field. +func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne { + _u.mutation.ResetLoadFactor() + _u.mutation.SetLoadFactor(v) + return _u +} + +// SetNillableLoadFactor sets the "load_factor" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne { + if v != nil { + _u.SetLoadFactor(*v) + } + return _u +} + +// AddLoadFactor adds value to the "load_factor" field. +func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne { + _u.mutation.AddLoadFactor(v) + return _u +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne { + _u.mutation.ClearLoadFactor() + return _u +} + // SetPriority sets the "priority" field. func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne { _u.mutation.ResetPriority() @@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if value, ok := _u.mutation.AddedConcurrency(); ok { _spec.AddField(account.FieldConcurrency, field.TypeInt, value) } + if value, ok := _u.mutation.LoadFactor(); ok { + _spec.SetField(account.FieldLoadFactor, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedLoadFactor(); ok { + _spec.AddField(account.FieldLoadFactor, field.TypeInt, value) + } + if _u.mutation.LoadFactorCleared() { + _spec.ClearField(account.FieldLoadFactor, field.TypeInt) + } if value, ok := _u.mutation.Priority(); ok { _spec.SetField(account.FieldPriority, field.TypeInt, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 85e94072..8e54f31c 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -106,6 +106,7 @@ var ( {Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "concurrency", Type: field.TypeInt, Default: 3}, + {Name: "load_factor", Type: field.TypeInt, Nullable: true}, {Name: "priority", Type: field.TypeInt, Default: 50}, {Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, @@ -132,7 +133,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[27]}, + Columns: []*schema.Column{AccountsColumns[28]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -151,52 +152,52 @@ var ( { Name: "account_status", Unique: false, - Columns: []*schema.Column{AccountsColumns[13]}, + Columns: []*schema.Column{AccountsColumns[14]}, }, { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[27]}, + Columns: []*schema.Column{AccountsColumns[28]}, }, { Name: "account_priority", Unique: false, - Columns: []*schema.Column{AccountsColumns[11]}, + Columns: []*schema.Column{AccountsColumns[12]}, }, { Name: "account_last_used_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[15]}, + Columns: []*schema.Column{AccountsColumns[16]}, }, { Name: "account_schedulable", Unique: false, - Columns: []*schema.Column{AccountsColumns[18]}, + Columns: []*schema.Column{AccountsColumns[19]}, }, { Name: "account_rate_limited_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[19]}, + Columns: []*schema.Column{AccountsColumns[20]}, }, { Name: "account_rate_limit_reset_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[20]}, + Columns: []*schema.Column{AccountsColumns[21]}, }, { Name: "account_overload_until", Unique: false, - Columns: []*schema.Column{AccountsColumns[21]}, + Columns: []*schema.Column{AccountsColumns[22]}, }, { Name: "account_platform_priority", Unique: false, - Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]}, + Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[12]}, }, { Name: "account_priority_status", Unique: false, - Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]}, + Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]}, }, { Name: "account_deleted_at", diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 85e2ea71..6c6194a6 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -2260,6 +2260,8 @@ type AccountMutation struct { extra *map[string]interface{} concurrency *int addconcurrency *int + load_factor *int + addload_factor *int priority *int addpriority *int rate_multiplier *float64 @@ -2845,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() { m.addconcurrency = nil } +// SetLoadFactor sets the "load_factor" field. +func (m *AccountMutation) SetLoadFactor(i int) { + m.load_factor = &i + m.addload_factor = nil +} + +// LoadFactor returns the value of the "load_factor" field in the mutation. +func (m *AccountMutation) LoadFactor() (r int, exists bool) { + v := m.load_factor + if v == nil { + return + } + return *v, true +} + +// OldLoadFactor returns the old "load_factor" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLoadFactor is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLoadFactor requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLoadFactor: %w", err) + } + return oldValue.LoadFactor, nil +} + +// AddLoadFactor adds i to the "load_factor" field. +func (m *AccountMutation) AddLoadFactor(i int) { + if m.addload_factor != nil { + *m.addload_factor += i + } else { + m.addload_factor = &i + } +} + +// AddedLoadFactor returns the value that was added to the "load_factor" field in this mutation. +func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) { + v := m.addload_factor + if v == nil { + return + } + return *v, true +} + +// ClearLoadFactor clears the value of the "load_factor" field. +func (m *AccountMutation) ClearLoadFactor() { + m.load_factor = nil + m.addload_factor = nil + m.clearedFields[account.FieldLoadFactor] = struct{}{} +} + +// LoadFactorCleared returns if the "load_factor" field was cleared in this mutation. +func (m *AccountMutation) LoadFactorCleared() bool { + _, ok := m.clearedFields[account.FieldLoadFactor] + return ok +} + +// ResetLoadFactor resets all changes to the "load_factor" field. +func (m *AccountMutation) ResetLoadFactor() { + m.load_factor = nil + m.addload_factor = nil + delete(m.clearedFields, account.FieldLoadFactor) +} + // SetPriority sets the "priority" field. func (m *AccountMutation) SetPriority(i int) { m.priority = &i @@ -3773,7 +3845,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 27) + fields := make([]string, 0, 28) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -3807,6 +3879,9 @@ func (m *AccountMutation) Fields() []string { if m.concurrency != nil { fields = append(fields, account.FieldConcurrency) } + if m.load_factor != nil { + fields = append(fields, account.FieldLoadFactor) + } if m.priority != nil { fields = append(fields, account.FieldPriority) } @@ -3885,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.ProxyID() case account.FieldConcurrency: return m.Concurrency() + case account.FieldLoadFactor: + return m.LoadFactor() case account.FieldPriority: return m.Priority() case account.FieldRateMultiplier: @@ -3948,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldProxyID(ctx) case account.FieldConcurrency: return m.OldConcurrency(ctx) + case account.FieldLoadFactor: + return m.OldLoadFactor(ctx) case account.FieldPriority: return m.OldPriority(ctx) case account.FieldRateMultiplier: @@ -4066,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetConcurrency(v) return nil + case account.FieldLoadFactor: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLoadFactor(v) + return nil case account.FieldPriority: v, ok := value.(int) if !ok { @@ -4189,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, account.FieldConcurrency) } + if m.addload_factor != nil { + fields = append(fields, account.FieldLoadFactor) + } if m.addpriority != nil { fields = append(fields, account.FieldPriority) } @@ -4205,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) { switch name { case account.FieldConcurrency: return m.AddedConcurrency() + case account.FieldLoadFactor: + return m.AddedLoadFactor() case account.FieldPriority: return m.AddedPriority() case account.FieldRateMultiplier: @@ -4225,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil + case account.FieldLoadFactor: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddLoadFactor(v) + return nil case account.FieldPriority: v, ok := value.(int) if !ok { @@ -4256,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string { if m.FieldCleared(account.FieldProxyID) { fields = append(fields, account.FieldProxyID) } + if m.FieldCleared(account.FieldLoadFactor) { + fields = append(fields, account.FieldLoadFactor) + } if m.FieldCleared(account.FieldErrorMessage) { fields = append(fields, account.FieldErrorMessage) } @@ -4312,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error { case account.FieldProxyID: m.ClearProxyID() return nil + case account.FieldLoadFactor: + m.ClearLoadFactor() + return nil case account.FieldErrorMessage: m.ClearErrorMessage() return nil @@ -4386,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldConcurrency: m.ResetConcurrency() return nil + case account.FieldLoadFactor: + m.ResetLoadFactor() + return nil case account.FieldPriority: m.ResetPriority() return nil @@ -10191,7 +10298,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 31) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 2c7467f6..7ae4d253 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -212,29 +212,29 @@ func init() { // account.DefaultConcurrency holds the default value on creation for the concurrency field. account.DefaultConcurrency = accountDescConcurrency.Default.(int) // accountDescPriority is the schema descriptor for priority field. - accountDescPriority := accountFields[8].Descriptor() + accountDescPriority := accountFields[9].Descriptor() // account.DefaultPriority holds the default value on creation for the priority field. account.DefaultPriority = accountDescPriority.Default.(int) // accountDescRateMultiplier is the schema descriptor for rate_multiplier field. - accountDescRateMultiplier := accountFields[9].Descriptor() + accountDescRateMultiplier := accountFields[10].Descriptor() // account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64) // accountDescStatus is the schema descriptor for status field. - accountDescStatus := accountFields[10].Descriptor() + accountDescStatus := accountFields[11].Descriptor() // account.DefaultStatus holds the default value on creation for the status field. account.DefaultStatus = accountDescStatus.Default.(string) // account.StatusValidator is a validator for the "status" field. It is called by the builders before save. account.StatusValidator = accountDescStatus.Validators[0].(func(string) error) // accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field. - accountDescAutoPauseOnExpired := accountFields[14].Descriptor() + accountDescAutoPauseOnExpired := accountFields[15].Descriptor() // account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field. account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool) // accountDescSchedulable is the schema descriptor for schedulable field. - accountDescSchedulable := accountFields[15].Descriptor() + accountDescSchedulable := accountFields[16].Descriptor() // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. - accountDescSessionWindowStatus := accountFields[23].Descriptor() + accountDescSessionWindowStatus := accountFields[24].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 443f9e09..5616d399 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field { field.Int("concurrency"). Default(3), + field.Int("load_factor").Optional().Nillable(), + // priority: 账户优先级,数值越小优先级越高 // 调度器会优先使用高优先级的账户 field.Int("priority"). diff --git a/backend/go.sum b/backend/go.sum index 32e389a7..10161387 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -124,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM= github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= @@ -171,8 +173,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -182,7 +182,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -203,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -285,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo= +github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4= +github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -398,8 +403,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= -go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= -go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= @@ -455,8 +458,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk= diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 14f9e05d..f42159a8 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -102,6 +102,7 @@ type CreateAccountRequest struct { Concurrency int `json:"concurrency"` Priority int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` GroupIDs []int64 `json:"group_ids"` ExpiresAt *int64 `json:"expires_at"` AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` @@ -120,6 +121,7 @@ type UpdateAccountRequest struct { Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` GroupIDs *[]int64 `json:"group_ids"` ExpiresAt *int64 `json:"expires_at"` @@ -135,6 +137,7 @@ type BulkUpdateAccountsRequest struct { Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"` Schedulable *bool `json:"schedulable"` GroupIDs *[]int64 `json:"group_ids"` @@ -285,32 +288,48 @@ func (h *AccountHandler) List(c *gin.Context) { } } - // 仅非 lite 模式获取窗口费用(PostgreSQL 聚合查询,高开销) - if !lite && len(windowCostAccountIDs) > 0 { - windowCosts = make(map[int64]float64) - var mu sync.Mutex - g, gctx := errgroup.WithContext(c.Request.Context()) - g.SetLimit(10) // 限制并发数 - - for i := range accounts { - acc := &accounts[i] - if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 { - continue - } - accCopy := acc // 闭包捕获 - g.Go(func() error { - // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) - startTime := accCopy.GetCurrentWindowStartTime() - stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime) - if err == nil && stats != nil { - mu.Lock() - windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用 - mu.Unlock() + // 窗口费用获取:lite 模式从快照缓存读取,非 lite 模式执行 PostgreSQL 查询后写入缓存 + if len(windowCostAccountIDs) > 0 { + if lite { + // lite 模式:尝试从快照缓存读取 + cacheKey := buildWindowCostCacheKey(windowCostAccountIDs) + if cached, ok := accountWindowCostCache.Get(cacheKey); ok { + if costs, ok := cached.Payload.(map[int64]float64); ok { + windowCosts = costs } - return nil // 不返回错误,允许部分失败 - }) + } + // 缓存未命中则 windowCosts 保持 nil(仅发生在服务刚启动时) + } else { + // 非 lite 模式:执行 PostgreSQL 聚合查询(高开销) + windowCosts = make(map[int64]float64) + var mu sync.Mutex + g, gctx := errgroup.WithContext(c.Request.Context()) + g.SetLimit(10) // 限制并发数 + + for i := range accounts { + acc := &accounts[i] + if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 { + continue + } + accCopy := acc // 闭包捕获 + g.Go(func() error { + // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) + startTime := accCopy.GetCurrentWindowStartTime() + stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime) + if err == nil && stats != nil { + mu.Lock() + windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用 + mu.Unlock() + } + return nil // 不返回错误,允许部分失败 + }) + } + _ = g.Wait() + + // 查询完毕后写入快照缓存,供 lite 模式使用 + cacheKey := buildWindowCostCacheKey(windowCostAccountIDs) + accountWindowCostCache.Set(cacheKey, windowCosts) } - _ = g.Wait() } // Build response with concurrency info @@ -506,6 +525,7 @@ func (h *AccountHandler) Create(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, GroupIDs: req.GroupIDs, ExpiresAt: req.ExpiresAt, AutoPauseOnExpired: req.AutoPauseOnExpired, @@ -575,6 +595,7 @@ func (h *AccountHandler) Update(c *gin.Context) { Concurrency: req.Concurrency, // 指针类型,nil 表示未提供 Priority: req.Priority, // 指针类型,nil 表示未提供 RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, Status: req.Status, GroupIDs: req.GroupIDs, ExpiresAt: req.ExpiresAt, @@ -1101,6 +1122,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { req.Concurrency != nil || req.Priority != nil || req.RateMultiplier != nil || + req.LoadFactor != nil || req.Status != "" || req.Schedulable != nil || req.GroupIDs != nil || @@ -1119,6 +1141,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, Status: req.Status, Schedulable: req.Schedulable, GroupIDs: req.GroupIDs, @@ -1328,6 +1351,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } +// ResetQuota handles resetting account quota usage +// POST /api/v1/admin/accounts/:id/reset-quota +func (h *AccountHandler) ResetQuota(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + if err := h.adminService.ResetAccountQuota(c.Request.Context(), accountID); err != nil { + response.InternalError(c, "Failed to reset account quota: "+err.Error()) + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) +} + // GetTempUnschedulable handles getting temporary unschedulable status // GET /api/v1/admin/accounts/:id/temp-unschedulable func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) { diff --git a/backend/internal/handler/admin/account_window_cost_cache.go b/backend/internal/handler/admin/account_window_cost_cache.go new file mode 100644 index 00000000..3271b630 --- /dev/null +++ b/backend/internal/handler/admin/account_window_cost_cache.go @@ -0,0 +1,25 @@ +package admin + +import ( + "strconv" + "strings" + "time" +) + +var accountWindowCostCache = newSnapshotCache(30 * time.Second) + +func buildWindowCostCacheKey(accountIDs []int64) string { + if len(accountIDs) == 0 { + return "accounts_window_cost_empty" + } + var b strings.Builder + b.Grow(len(accountIDs) * 6) + _, _ = b.WriteString("accounts_window_cost:") + for i, id := range accountIDs { + if i > 0 { + _ = b.WriteByte(',') + } + _, _ = b.WriteString(strconv.FormatInt(id, 10)) + } + return b.String() +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index f3b99ddb..84a9f102 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -425,5 +425,9 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i return nil, service.ErrAPIKeyNotFound } +func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error { + return nil +} + // Ensure stub implements interface. var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index fe2a1d77..03b122f3 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -183,6 +183,7 @@ func AccountFromServiceShallow(a *service.Account) *Account { Extra: a.Extra, ProxyID: a.ProxyID, Concurrency: a.Concurrency, + LoadFactor: a.LoadFactor, Priority: a.Priority, RateMultiplier: a.BillingRateMultiplier(), Status: a.Status, @@ -248,6 +249,17 @@ func AccountFromServiceShallow(a *service.Account) *Account { } } + // 提取 API Key 账号配额限制(仅 apikey 类型有效) + if a.Type == service.AccountTypeAPIKey { + if limit := a.GetQuotaLimit(); limit > 0 { + out.QuotaLimit = &limit + } + used := a.GetQuotaUsed() + if out.QuotaLimit != nil { + out.QuotaUsed = &used + } + } + return out } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 920615f7..e7835170 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -131,6 +131,7 @@ type Account struct { Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` Concurrency int `json:"concurrency"` + LoadFactor *int `json:"load_factor,omitempty"` Priority int `json:"priority"` RateMultiplier float64 `json:"rate_multiplier"` Status string `json:"status"` @@ -185,6 +186,10 @@ type Account struct { CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"` CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + // API Key 账号配额限制 + QuotaLimit *float64 `json:"quota_limit,omitempty"` + QuotaUsed *float64 `json:"quota_used,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index d2d9790d..d2a849b1 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service return 0, nil } +func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error { + return nil +} + +func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error { + return nil +} + // ==================== Stub: SoraClient (用于 SoraGatewayService) ==================== var _ service.SoraClient = (*stubSoraClientForHandler)(nil) diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index b76ab67d..637462ad 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s return 0, nil } +func (r *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (r *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + func (r *stubAccountRepo) listSchedulable() []service.Account { var result []service.Account for _, acc := range r.accounts { diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index 4bbc68e7..b0a31a5f 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -15,6 +15,7 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ + {ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 746188ea..8826c048 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -57,25 +57,28 @@ type DashboardStats struct { // TrendDataPoint represents a single point in trend data type TrendDataPoint struct { - Date string `json:"date"` - Requests int64 `json:"requests"` - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - CacheTokens int64 `json:"cache_tokens"` - TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 + Date string `json:"date"` + Requests int64 `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheCreationTokens int64 `json:"cache_creation_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 } // ModelStat represents usage statistics for a single model type ModelStat struct { - Model string `json:"model"` - Requests int64 `json:"requests"` - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 + Model string `json:"model"` + Requests int64 `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheCreationTokens int64 `json:"cache_creation_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 } // GroupStat represents usage statistics for a single group diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 6f0c5424..ffbfd466 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -84,6 +84,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) } + if account.LoadFactor != nil { + builder.SetLoadFactor(*account.LoadFactor) + } if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) @@ -318,6 +321,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) } + if account.LoadFactor != nil { + builder.SetLoadFactor(*account.LoadFactor) + } else { + builder.ClearLoadFactor() + } if account.ProxyID != nil { builder.SetProxyID(*account.ProxyID) @@ -1223,6 +1231,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates args = append(args, *updates.RateMultiplier) idx++ } + if updates.LoadFactor != nil { + if *updates.LoadFactor <= 0 { + setClauses = append(setClauses, "load_factor = NULL") + } else { + setClauses = append(setClauses, "load_factor = $"+itoa(idx)) + args = append(args, *updates.LoadFactor) + idx++ + } + } if updates.Status != nil { setClauses = append(setClauses, "status = $"+itoa(idx)) args = append(args, *updates.Status) @@ -1545,6 +1562,7 @@ func accountEntityToService(m *dbent.Account) *service.Account { Concurrency: m.Concurrency, Priority: m.Priority, RateMultiplier: &rateMultiplier, + LoadFactor: m.LoadFactor, Status: m.Status, ErrorMessage: derefString(m.ErrorMessage), LastUsedAt: m.LastUsedAt, @@ -1657,3 +1675,60 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va return r.accountsToService(ctx, accounts) } + +// IncrementQuotaUsed 原子递增账号的 extra.quota_used 字段 +func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + rows, err := r.sql.QueryContext(ctx, + `UPDATE accounts SET extra = jsonb_set( + COALESCE(extra, '{}'::jsonb), + '{quota_used}', + to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1) + ), updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING + COALESCE((extra->>'quota_used')::numeric, 0), + COALESCE((extra->>'quota_limit')::numeric, 0)`, + amount, id) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + var newUsed, limit float64 + if rows.Next() { + if err := rows.Scan(&newUsed, &limit); err != nil { + return err + } + } + if err := rows.Err(); err != nil { + return err + } + + // 配额刚超限时触发调度快照刷新,使账号及时从调度候选中移除 + if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err) + } + } + return nil +} + +// ResetQuotaUsed 重置账号的 extra.quota_used 为 0 +func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, + `UPDATE accounts SET extra = jsonb_set( + COALESCE(extra, '{}'::jsonb), + '{quota_used}', + '0'::jsonb + ), updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL`, + id) + if err != nil { + return err + } + // 重置配额后触发调度快照刷新,使账号重新参与调度 + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, err) + } + return nil +} diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index f6054828..1264f6bb 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg) } var usageResp service.ClaudeUsageResponse diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 44079a55..7fc11b78 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1664,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(actual_cost), 0) as actual_cost @@ -1747,7 +1751,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st total_requests as requests, input_tokens, output_tokens, - (cache_creation_tokens + cache_read_tokens) as cache_tokens, + cache_creation_tokens, + cache_read_tokens, (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, total_cost as cost, actual_cost @@ -1762,7 +1767,8 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st total_requests as requests, input_tokens, output_tokens, - (cache_creation_tokens + cache_read_tokens) as cache_tokens, + cache_creation_tokens, + cache_read_tokens, (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, total_cost as cost, actual_cost @@ -1806,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, %s @@ -2622,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) { &row.Requests, &row.InputTokens, &row.OutputTokens, - &row.CacheTokens, + &row.CacheCreationTokens, + &row.CacheReadTokens, &row.TotalTokens, &row.Cost, &row.ActualCost, @@ -2646,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) { &row.Requests, &row.InputTokens, &row.OutputTokens, + &row.CacheCreationTokens, + &row.CacheReadTokens, &row.TotalTokens, &row.Cost, &row.ActualCost, diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 54eb81e1..53fb7227 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -125,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)"). WithArgs(start, end, requestType). - WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"})) + WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"})) trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil) require.NoError(t, err) @@ -144,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). WithArgs(start, end, requestType). - WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"})) + WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"})) stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil) require.NoError(t, err) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 40b2d592..aafbbe21 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1096,6 +1096,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map return errors.New("not implemented") } +func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { s.bulkUpdateIDs = append([]int64{}, ids...) return int64(len(ids)), nil diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index e9f9bf62..2e53feb3 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -252,6 +252,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) + accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota) accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 7d56b754..8eb3748c 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -28,6 +28,7 @@ type Account struct { // RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。 // 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。 RateMultiplier *float64 + LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency Status string ErrorMessage string LastUsedAt *time.Time @@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 { return *a.RateMultiplier } +func (a *Account) EffectiveLoadFactor() int { + if a == nil { + return 1 + } + if a.LoadFactor != nil && *a.LoadFactor > 0 { + return *a.LoadFactor + } + if a.Concurrency > 0 { + return a.Concurrency + } + return 1 +} + func (a *Account) IsSchedulable() bool { if !a.IsActive() || !a.Schedulable { return false @@ -1117,6 +1131,38 @@ func (a *Account) GetCacheTTLOverrideTarget() string { return "5m" } +// GetQuotaLimit 获取 API Key 账号的配额限制(美元) +// 返回 0 表示未启用 +func (a *Account) GetQuotaLimit() float64 { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["quota_limit"]; ok { + return parseExtraFloat64(v) + } + return 0 +} + +// GetQuotaUsed 获取 API Key 账号的已用配额(美元) +func (a *Account) GetQuotaUsed() float64 { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["quota_used"]; ok { + return parseExtraFloat64(v) + } + return 0 +} + +// IsQuotaExceeded 检查 API Key 账号配额是否已超限 +func (a *Account) IsQuotaExceeded() bool { + limit := a.GetQuotaLimit() + if limit <= 0 { + return false + } + return a.GetQuotaUsed() >= limit +} + // GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // 返回 0 表示未启用 func (a *Account) GetWindowCostLimit() float64 { diff --git a/backend/internal/service/account_load_factor_test.go b/backend/internal/service/account_load_factor_test.go new file mode 100644 index 00000000..a4d78a4b --- /dev/null +++ b/backend/internal/service/account_load_factor_test.go @@ -0,0 +1,46 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func intPtrHelper(v int) *int { return &v } + +func TestEffectiveLoadFactor_NilAccount(t *testing.T) { + var a *Account + require.Equal(t, 1, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) { + a := &Account{Concurrency: 5} + require.Equal(t, 5, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) { + a := &Account{Concurrency: 0} + require.Equal(t, 1, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) { + a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)} + require.Equal(t, 20, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) { + a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)} + require.Equal(t, 5, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) { + a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)} + require.Equal(t, 3, a.EffectiveLoadFactor()) +} + +func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) { + a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)} + require.Equal(t, 1, a.EffectiveLoadFactor()) +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 18a70c5c..26c0b1c2 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -68,6 +68,10 @@ type AccountRepository interface { UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) + // IncrementQuotaUsed 原子递增 API Key 账号的配额用量 + IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error + // ResetQuotaUsed 重置 API Key 账号的配额用量为 0 + ResetQuotaUsed(ctx context.Context, id int64) error } // AccountBulkUpdate describes the fields that can be updated in a bulk operation. @@ -78,6 +82,7 @@ type AccountBulkUpdate struct { Concurrency *int Priority *int RateMultiplier *float64 + LoadFactor *int Status *string Schedulable *bool Credentials map[string]any diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 768cf7b7..c96b436f 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -199,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A panic("unexpected BulkUpdate call") } +func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。 // 预期行为: // - ExistsByID 返回 false(账号不存在) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 99046e30..9557e175 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -180,7 +180,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int } if account.Platform == PlatformAntigravity { - return s.testAntigravityAccountConnection(c, account, modelID) + return s.routeAntigravityTest(c, account, modelID) } if account.Platform == PlatformSora { @@ -1177,6 +1177,18 @@ func truncateSoraErrorBody(body []byte, max int) string { return soraerror.TruncateBody(body, max) } +// routeAntigravityTest 路由 Antigravity 账号的测试请求。 +// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 +func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error { + if account.Type == AccountTypeAPIKey { + if strings.HasPrefix(modelID, "gemini-") { + return s.testGeminiAccountConnection(c, account, modelID) + } + return s.testClaudeAccountConnection(c, account, modelID) + } + return s.testAntigravityAccountConnection(c, account, modelID) +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 67e7c783..446cc148 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -84,6 +84,7 @@ type AdminService interface { DeleteRedeemCode(ctx context.Context, id int64) error BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) + ResetAccountQuota(ctx context.Context, id int64) error } // CreateUserInput represents input for creating a new user via admin operations. @@ -195,6 +196,7 @@ type CreateAccountInput struct { Concurrency int Priority int RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int GroupIDs []int64 ExpiresAt *int64 AutoPauseOnExpired *bool @@ -215,6 +217,7 @@ type UpdateAccountInput struct { Concurrency *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0" RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int Status string GroupIDs *[]int64 ExpiresAt *int64 @@ -230,6 +233,7 @@ type BulkUpdateAccountsInput struct { Concurrency *int Priority *int RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + LoadFactor *int Status string Schedulable *bool GroupIDs *[]int64 @@ -1413,6 +1417,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } account.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil && *input.LoadFactor > 0 { + if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } + account.LoadFactor = input.LoadFactor + } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err } @@ -1458,6 +1468,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.Credentials = input.Credentials } if len(input.Extra) > 0 { + // 保留 quota_used,防止编辑账号时意外重置配额用量 + if oldQuotaUsed, ok := account.Extra["quota_used"]; ok { + input.Extra["quota_used"] = oldQuotaUsed + } account.Extra = input.Extra } if input.ProxyID != nil { @@ -1483,6 +1497,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } account.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil { + if *input.LoadFactor <= 0 { + account.LoadFactor = nil // 0 或负数表示清除 + } else if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } else { + account.LoadFactor = input.LoadFactor + } + } if input.Status != "" { account.Status = input.Status } @@ -1616,6 +1639,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if input.RateMultiplier != nil { repoUpdates.RateMultiplier = input.RateMultiplier } + if input.LoadFactor != nil { + if *input.LoadFactor <= 0 { + repoUpdates.LoadFactor = nil // 0 或负数表示清除 + } else if *input.LoadFactor > 10000 { + return nil, errors.New("load_factor must be <= 10000") + } else { + repoUpdates.LoadFactor = input.LoadFactor + } + } if input.Status != "" { repoUpdates.Status = &input.Status } @@ -2439,3 +2471,7 @@ func (e *MixedChannelError) Error() string { return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.", e.GroupName, e.CurrentPlatform, e.OtherPlatform) } + +func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error { + return s.accountRepo.ResetQuotaUsed(ctx, id) +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 5d67c808..d058c25a 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -43,15 +43,24 @@ type BillingCache interface { // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) type ModelPricing struct { - InputPricePerToken float64 // 每token输入价格 (USD) - OutputPricePerToken float64 // 每token输出价格 (USD) - CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) - CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) - CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) - CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) - SupportsCacheBreakdown bool // 是否支持详细的缓存分类 + InputPricePerToken float64 // 每token输入价格 (USD) + OutputPricePerToken float64 // 每token输出价格 (USD) + CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) + CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) + CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) + CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) + SupportsCacheBreakdown bool // 是否支持详细的缓存分类 + LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 + LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 + LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 } +const ( + openAIGPT54LongContextInputThreshold = 272000 + openAIGPT54LongContextInputMultiplier = 2.0 + openAIGPT54LongContextOutputMultiplier = 1.5 +) + // UsageTokens 使用的token数量 type UsageTokens struct { InputTokens int @@ -161,6 +170,35 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok SupportsCacheBreakdown: false, } + + // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费) + s.fallbackPrices["gpt-5.1"] = &ModelPricing{ + InputPricePerToken: 1.25e-6, // $1.25 per MTok + OutputPricePerToken: 10e-6, // $10 per MTok + CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok + CacheReadPricePerToken: 0.125e-6, + SupportsCacheBreakdown: false, + } + // OpenAI GPT-5.4(业务指定价格) + s.fallbackPrices["gpt-5.4"] = &ModelPricing{ + InputPricePerToken: 2.5e-6, // $2.5 per MTok + OutputPricePerToken: 15e-6, // $15 per MTok + CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok + CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok + SupportsCacheBreakdown: false, + LongContextInputThreshold: openAIGPT54LongContextInputThreshold, + LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, + LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, + } + // Codex 族兜底统一按 GPT-5.1 Codex 价格计费 + s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{ + InputPricePerToken: 1.5e-6, // $1.5 per MTok + OutputPricePerToken: 12e-6, // $12 per MTok + CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok + CacheReadPricePerToken: 0.15e-6, + SupportsCacheBreakdown: false, + } + s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"] } // getFallbackPricing 根据模型系列获取回退价格 @@ -189,12 +227,30 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { } return s.fallbackPrices["claude-3-haiku"] } + // Claude 未知型号统一回退到 Sonnet,避免计费中断。 + if strings.Contains(modelLower, "claude") { + return s.fallbackPrices["claude-sonnet-4"] + } if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") { return s.fallbackPrices["gemini-3.1-pro"] } - // 默认使用Sonnet价格 - return s.fallbackPrices["claude-sonnet-4"] + // OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。 + if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { + normalized := normalizeCodexModel(modelLower) + switch normalized { + case "gpt-5.4": + return s.fallbackPrices["gpt-5.4"] + case "gpt-5.3-codex": + return s.fallbackPrices["gpt-5.3-codex"] + case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest": + return s.fallbackPrices["gpt-5.1-codex"] + case "gpt-5.1": + return s.fallbackPrices["gpt-5.1"] + } + } + + return nil } // GetModelPricing 获取模型价格配置 @@ -212,15 +268,18 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { price5m := litellmPricing.CacheCreationInputTokenCost price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr enableBreakdown := price1h > 0 && price1h > price5m - return &ModelPricing{ - InputPricePerToken: litellmPricing.InputCostPerToken, - OutputPricePerToken: litellmPricing.OutputCostPerToken, - CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, - CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, - CacheCreation5mPrice: price5m, - CacheCreation1hPrice: price1h, - SupportsCacheBreakdown: enableBreakdown, - }, nil + return s.applyModelSpecificPricingPolicy(model, &ModelPricing{ + InputPricePerToken: litellmPricing.InputCostPerToken, + OutputPricePerToken: litellmPricing.OutputCostPerToken, + CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, + CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, + CacheCreation5mPrice: price5m, + CacheCreation1hPrice: price1h, + SupportsCacheBreakdown: enableBreakdown, + LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, + LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, + LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, + }), nil } } @@ -228,7 +287,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { fallback := s.getFallbackPricing(model) if fallback != nil { log.Printf("[Billing] Using fallback pricing for model: %s", model) - return fallback, nil + return s.applyModelSpecificPricingPolicy(model, fallback), nil } return nil, fmt.Errorf("pricing not found for model: %s", model) @@ -242,12 +301,18 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul } breakdown := &CostBreakdown{} + inputPricePerToken := pricing.InputPricePerToken + outputPricePerToken := pricing.OutputPricePerToken + if s.shouldApplySessionLongContextPricing(tokens, pricing) { + inputPricePerToken *= pricing.LongContextInputMultiplier + outputPricePerToken *= pricing.LongContextOutputMultiplier + } // 计算输入token费用(使用per-token价格) - breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken + breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken // 计算输出token费用 - breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken + breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken // 计算缓存费用 if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { @@ -279,6 +344,45 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul return breakdown, nil } +func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing { + if pricing == nil { + return nil + } + if !isOpenAIGPT54Model(model) { + return pricing + } + if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 { + return pricing + } + cloned := *pricing + if cloned.LongContextInputThreshold <= 0 { + cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold + } + if cloned.LongContextInputMultiplier <= 0 { + cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier + } + if cloned.LongContextOutputMultiplier <= 0 { + cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier + } + return &cloned +} + +func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool { + if pricing == nil || pricing.LongContextInputThreshold <= 0 { + return false + } + if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 { + return false + } + totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens + return totalInputTokens > pricing.LongContextInputThreshold +} + +func isOpenAIGPT54Model(model string) bool { + normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model))) + return normalized == "gpt-5.4" +} + // CalculateCostWithConfig 使用配置中的默认倍率计算费用 func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) { multiplier := s.cfg.Default.RateMultiplier diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 5eb278f6..0ba52e56 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -133,7 +133,7 @@ func TestGetModelPricing_CaseInsensitive(t *testing.T) { require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken) } -func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { +func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) { svc := newTestBillingService() // 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格 @@ -142,6 +142,93 @@ func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) } +func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-unknown-model") + require.Error(t, err) + require.Nil(t, pricing) + require.Contains(t, err.Error(), "pricing not found") +} + +func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.1") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.4") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12) + require.Equal(t, 272000, pricing.LongContextInputThreshold) + require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) +} + +func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 300000, + OutputTokens: 4000, + } + + cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0) + require.NoError(t, err) + + expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0 + expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5 + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) +} + +func TestGetFallbackPricing_FamilyMatching(t *testing.T) { + svc := newTestBillingService() + + tests := []struct { + name string + model string + expectedInput float64 + expectNilPricing bool + }{ + {name: "empty model", model: " ", expectNilPricing: true}, + {name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6}, + {name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6}, + {name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6}, + {name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6}, + {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, + {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6}, + {name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, + {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, + {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, + {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, + {name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true}, + {name: "non supported family", model: "qwen-max", expectNilPricing: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pricing := svc.getFallbackPricing(tt.model) + if tt.expectNilPricing { + require.Nil(t, pricing) + return + } + require.NotNil(t, pricing) + require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12) + }) + } +} func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) { svc := newTestBillingService() diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 1cb3c61e..320ceaa7 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -187,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64 return 0, nil } +func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // Verify interface implementation var _ AccountRepository = (*mockAccountRepoForPlatform)(nil) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 132361f4..9f5c8299 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1228,6 +1228,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue } + // 配额检查 + if !s.isAccountSchedulableForQuota(account) { + continue + } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, account, false) { filteredWindowCost++ @@ -1260,6 +1264,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && + s.isAccountSchedulableForQuota(stickyAccount) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 @@ -1311,7 +1316,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro for _, acc := range routingCandidates { routingLoads = append(routingLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) @@ -1416,6 +1421,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && + s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 @@ -1480,6 +1486,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + // 配额检查 + if !s.isAccountSchedulableForQuota(acc) { + continue + } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue @@ -1499,7 +1509,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro for _, acc := range candidates { accountLoads = append(accountLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } @@ -2113,6 +2123,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts [] return context.WithValue(ctx, windowCostPrefetchContextKey, costs) } +// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内 +// 仅适用于配置了 quota_limit 的 apikey 类型账号 +func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool { + if account.Type != AccountTypeAPIKey { + return true + } + return !account.IsQuotaExceeded() +} + // isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 // 仅适用于 Anthropic OAuth/SetupToken 账号 // 返回 true 表示可调度,false 表示不可调度 @@ -2590,7 +2609,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2644,6 +2663,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2700,7 +2722,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { return account, nil } } @@ -2743,6 +2765,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2818,7 +2843,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) @@ -2874,6 +2899,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -2930,7 +2958,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2975,6 +3003,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } + if !s.isAccountSchedulableForQuota(acc) { + continue + } if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } @@ -3289,6 +3320,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo if account.Platform == PlatformSora { return s.isSoraModelSupportedByAccount(account, requestedModel) } + // OpenAI 透传模式:仅替换认证,允许所有模型 + if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() { + return true + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) @@ -6379,6 +6414,89 @@ type APIKeyQuotaUpdater interface { UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error } +// postUsageBillingParams 统一扣费所需的参数 +type postUsageBillingParams struct { + Cost *CostBreakdown + User *User + APIKey *APIKey + Account *Account + Subscription *UserSubscription + IsSubscriptionBill bool + AccountRateMultiplier float64 + APIKeyService APIKeyQuotaUpdater +} + +// postUsageBilling 统一处理使用量记录后的扣费逻辑: +// - 订阅/余额扣费 +// - API Key 配额更新 +// - API Key 限速用量更新 +// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) +func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + cost := p.Cost + + // 1. 订阅 / 余额扣费 + if p.IsSubscriptionBill { + if cost.TotalCost > 0 { + if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil { + slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) + } + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) + } + } else { + if cost.ActualCost > 0 { + if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil { + slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) + } + deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) + } + } + + // 2. API Key 配额 + if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + // 3. API Key 限速用量 + if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) + } + deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost) + } + + // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) + if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 { + accountCost := cost.TotalCost * p.AccountRateMultiplier + if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil { + slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) + } + } + + // 5. 更新账号最近使用时间 + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) +} + +// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) +type billingDeps struct { + accountRepo AccountRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCacheService *BillingCacheService + deferredService *DeferredService +} + +func (s *GatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { result := input.Result @@ -6542,45 +6660,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu shouldBill := inserted || err != nil - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if shouldBill && cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if shouldBill && cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // 更新 API Key 配额(如果设置了配额限制) - if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err) - } - } - - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) - } - s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } @@ -6740,44 +6834,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * shouldBill := inserted || err != nil - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if shouldBill && cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if shouldBill && cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - // API Key 独立配额扣费 - if input.APIKeyService != nil && apiKey.Quota > 0 { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err) - } - } - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) - } - s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 9476e984..b0b804eb 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -176,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, return 0, nil } +func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { + return nil +} + +func (m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error { + return nil +} + // Verify interface implementation var _ AccountRepository = (*mockAccountRepoForGemini)(nil) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 99013ce5..cf4bc26e 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -342,6 +342,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( } cfg := s.service.schedulingConfig() + // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 if s.service.concurrencyService != nil { return &AccountSelectionResult{ Account: account, @@ -590,7 +591,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( filtered = append(filtered, account) loadReq = append(loadReq, AccountWithConcurrency{ ID: account.ID, - MaxConcurrency: account.Concurrency, + MaxConcurrency: account.EffectiveLoadFactor(), }) } if len(filtered) == 0 { @@ -703,6 +704,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } cfg := s.service.schedulingConfig() + // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 candidate := selectionOrder[0] return &AccountSelectionResult{ Account: candidate.account, diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 16befb82..9bc48cf6 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -9,6 +9,13 @@ import ( var codexCLIInstructions string var codexModelMap = map[string]string{ + "gpt-5.4": "gpt-5.4", + "gpt-5.4-none": "gpt-5.4", + "gpt-5.4-low": "gpt-5.4", + "gpt-5.4-medium": "gpt-5.4", + "gpt-5.4-high": "gpt-5.4", + "gpt-5.4-xhigh": "gpt-5.4", + "gpt-5.4-chat-latest": "gpt-5.4", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-none": "gpt-5.3-codex", "gpt-5.3-low": "gpt-5.3-codex", @@ -154,6 +161,9 @@ func normalizeCodexModel(model string) string { normalized := strings.ToLower(modelID) + if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { + return "gpt-5.4" + } if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") { return "gpt-5.2-codex" } diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 27093f6c..7ee4bbc8 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -167,6 +167,10 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) { cases := map[string]string{ + "gpt-5.4": "gpt-5.4", + "gpt-5.4-high": "gpt-5.4", + "gpt-5.4-chat-latest": "gpt-5.4", + "gpt 5.4": "gpt-5.4", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d92b2ecf..73bdba65 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -319,6 +319,16 @@ func NewOpenAIGatewayService( return svc } +func (s *OpenAIGatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + // CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 // 应在应用优雅关闭时调用。 func (s *OpenAIGatewayService) CloseOpenAIWSPool() { @@ -1242,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex for _, acc := range candidates { accountLoads = append(accountLoads, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: acc.Concurrency, + MaxConcurrency: acc.EffectiveLoadFactor(), }) } @@ -3474,37 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec shouldBill := inserted || err != nil - // Deduct based on billing type - if isSubscriptionBilling { - if shouldBill && cost.TotalCost > 0 { - _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - if shouldBill && cost.ActualCost > 0 { - _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // Update API key quota if applicable (only for balance mode with quota set) - if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err) - } - } - - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err) - } - s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index a5c2fd7a..7b6591fa 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -864,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool { strings.Contains(message, "unexpected eof") || strings.Contains(message, "use of closed network connection") || strings.Contains(message, "connection reset by peer") || - strings.Contains(message, "broken pipe") + strings.Contains(message, "broken pipe") || + strings.Contains(message, "an established connection was aborted") } func classifyOpenAIWSReadFallbackReason(err error) string { diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index 92b37e73..a571dd4d 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -64,8 +64,12 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts if acc.ID <= 0 { continue } - if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev { - unique[acc.ID] = acc.Concurrency + c := acc.Concurrency + if c <= 0 { + c = 1 + } + if prev, ok := unique[acc.ID]; !ok || c > prev { + unique[acc.ID] = c } } diff --git a/backend/internal/service/ops_metrics_collector.go b/backend/internal/service/ops_metrics_collector.go index 30adaae0..f93481e7 100644 --- a/backend/internal/service/ops_metrics_collector.go +++ b/backend/internal/service/ops_metrics_collector.go @@ -389,13 +389,9 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con if acc.ID <= 0 { continue } - maxConc := acc.Concurrency - if maxConc < 0 { - maxConc = 0 - } batch = append(batch, AccountWithConcurrency{ ID: acc.ID, - MaxConcurrency: maxConc, + MaxConcurrency: acc.Concurrency, }) } if len(batch) == 0 { diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 41e8b5eb..897623d6 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -21,8 +21,19 @@ import ( ) var ( - openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`) - openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) + openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`) + openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) + openAIGPT54FallbackPricing = &LiteLLMModelPricing{ + InputCostPerToken: 2.5e-06, // $2.5 per MTok + OutputCostPerToken: 1.5e-05, // $15 per MTok + CacheReadInputTokenCost: 2.5e-07, // $0.25 per MTok + LongContextInputTokenThreshold: 272000, + LongContextInputCostMultiplier: 2.0, + LongContextOutputCostMultiplier: 1.5, + LiteLLMProvider: "openai", + Mode: "chat", + SupportsPromptCaching: true, + } ) // LiteLLMModelPricing LiteLLM价格数据结构 @@ -33,6 +44,9 @@ type LiteLLMModelPricing struct { CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"` CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + LongContextInputTokenThreshold int `json:"long_context_input_token_threshold,omitempty"` + LongContextInputCostMultiplier float64 `json:"long_context_input_cost_multiplier,omitempty"` + LongContextOutputCostMultiplier float64 `json:"long_context_output_cost_multiplier,omitempty"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` @@ -660,7 +674,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // 2. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) // 3. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) // 4. gpt-5.3-codex -> gpt-5.2-codex -// 5. 最终回退到 DefaultTestModel (gpt-5.1-codex) +// 5. gpt-5.4* -> 业务静态兜底价 +// 6. 最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { if strings.HasPrefix(model, "gpt-5.3-codex-spark") { if pricing, ok := s.pricingData["gpt-5.1-codex"]; ok { @@ -690,6 +705,12 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { } } + if strings.HasPrefix(model, "gpt-5.4") { + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)")) + return openAIGPT54FallbackPricing + } + // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 127ff342..6b67c55a 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -51,3 +51,20 @@ func TestGetModelPricing_OpenAIFallbackMatchedLoggedAsInfo(t *testing.T) { require.True(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "info")) require.False(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "warn")) } + +func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": &LiteLLMModelPricing{InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("gpt-5.4") + require.NotNil(t, got) + require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12) + require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12) + require.InDelta(t, 2.5e-7, got.CacheReadInputTokenCost, 1e-12) + require.Equal(t, 272000, got.LongContextInputTokenThreshold) + require.InDelta(t, 2.0, got.LongContextInputCostMultiplier, 1e-12) + require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12) +} diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go index 22018bcd..53e5c568 100644 --- a/backend/internal/service/subscription_calculate_progress_test.go +++ b/backend/internal/service/subscription_calculate_progress_test.go @@ -34,7 +34,7 @@ func TestCalculateProgress_BasicFields(t *testing.T) { assert.Equal(t, int64(100), progress.ID) assert.Equal(t, "Premium", progress.GroupName) assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt) - assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天 + assert.True(t, progress.ExpiresInDays == 29 || progress.ExpiresInDays == 30, "ExpiresInDays should be 29 or 30, got %d", progress.ExpiresInDays) assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil") assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil") assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil") diff --git a/backend/migrations/067_add_account_load_factor.sql b/backend/migrations/067_add_account_load_factor.sql new file mode 100644 index 00000000..6805e8c2 --- /dev/null +++ b/backend/migrations/067_add_account_load_factor.sql @@ -0,0 +1 @@ +ALTER TABLE accounts ADD COLUMN IF NOT EXISTS load_factor INTEGER; diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json index 650e128e..72860bf9 100644 --- a/backend/resources/model-pricing/model_prices_and_context_window.json +++ b/backend/resources/model-pricing/model_prices_and_context_window.json @@ -5140,6 +5140,39 @@ "supports_vision": true, "supports_web_search": true }, + "gpt-5.4": { + "cache_read_input_token_cost": 2.5e-07, + "input_cost_per_token": 2.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.5e-05, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text", + "image" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_service_tier": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, "gpt-5.3-codex": { "cache_read_input_token_cost": 1.75e-07, "cache_read_input_token_cost_priority": 3.5e-07, diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 25bb7b7b..5524e0cb 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -240,6 +240,18 @@ export async function clearRateLimit(id: number): Promise { return data } +/** + * Reset account quota usage + * @param id - Account ID + * @returns Updated account + */ +export async function resetAccountQuota(id: number): Promise { + const { data } = await apiClient.post( + `/admin/accounts/${id}/reset-quota` + ) + return data +} + /** * Get temporary unschedulable status * @param id - Account ID @@ -576,6 +588,7 @@ export const accountsAPI = { getTodayStats, getBatchTodayStats, clearRateLimit, + resetAccountQuota, getTempUnschedulableStatus, resetTempUnschedulable, setSchedulable, diff --git a/frontend/src/components/account/AccountCapacityCell.vue b/frontend/src/components/account/AccountCapacityCell.vue index 2a4babf2..2001b185 100644 --- a/frontend/src/components/account/AccountCapacityCell.vue +++ b/frontend/src/components/account/AccountCapacityCell.vue @@ -71,6 +71,24 @@ {{ rpmStrategyTag }} + + +
+ + + + + ${{ formatCost(currentQuotaUsed) }} + / + ${{ formatCost(account.quota_limit) }} + +
@@ -286,6 +304,48 @@ const rpmTooltip = computed(() => { } }) +// 是否显示配额限制(仅 apikey 类型且设置了 quota_limit) +const showQuotaLimit = computed(() => { + return ( + props.account.type === 'apikey' && + props.account.quota_limit !== undefined && + props.account.quota_limit !== null && + props.account.quota_limit > 0 + ) +}) + +// 当前已用配额 +const currentQuotaUsed = computed(() => props.account.quota_used ?? 0) + +// 配额状态样式 +const quotaClass = computed(() => { + if (!showQuotaLimit.value) return '' + + const used = currentQuotaUsed.value + const limit = props.account.quota_limit || 0 + + if (used >= limit) { + return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400' + } + if (used >= limit * 0.8) { + return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400' + } + return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400' +}) + +// 配额提示文字 +const quotaTooltip = computed(() => { + if (!showQuotaLimit.value) return '' + + const used = currentQuotaUsed.value + const limit = props.account.quota_limit || 0 + + if (used >= limit) { + return t('admin.accounts.capacity.quota.exceeded') + } + return t('admin.accounts.capacity.quota.normal') +}) + // 格式化费用显示 const formatCost = (value: number | null | undefined) => { if (value === null || value === undefined) return '0' diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 1c83e658..1d6f32fe 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -469,7 +469,7 @@ -
+
+
+
+ + +
+ +

{{ t('admin.accounts.loadFactorHint') }}

+
+ + +
@@ -1749,10 +1752,18 @@
-
+
- + +
+
+ + +

{{ t('admin.accounts.loadFactorHint') }}

@@ -2337,11 +2348,12 @@ import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' +import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - OPENAI_WS_MODE_CTX_POOL, + // OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, @@ -2460,6 +2472,7 @@ const accountCategory = ref<'oauth-based' | 'apikey'>('oauth-based') // UI selec const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') +const editQuotaLimit = ref(null) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) @@ -2542,7 +2555,8 @@ const geminiSelectedTier = computed(() => { const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 + // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) @@ -2632,6 +2646,7 @@ const form = reactive({ credentials: {} as Record, proxy_id: null as number | null, concurrency: 10, + load_factor: null as number | null, priority: 1, rate_multiplier: 1, group_ids: [] as number[], @@ -3111,6 +3126,7 @@ const resetForm = () => { form.credentials = {} form.proxy_id = null form.concurrency = 10 + form.load_factor = null form.priority = 1 form.rate_multiplier = 1 form.group_ids = [] @@ -3119,6 +3135,7 @@ const resetForm = () => { addMethod.value = 'oauth' apiKeyBaseUrl.value = 'https://api.anthropic.com' apiKeyValue.value = '' + editQuotaLimit.value = null modelMappings.value = [] modelRestrictionMode.value = 'whitelist' allowedModels.value = [...claudeModels] // Default fill related models @@ -3482,6 +3499,7 @@ const handleImportAccessToken = async (accessTokenInput: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3532,15 +3550,21 @@ const createAccountAndFinish = async ( if (!applyTempUnschedConfig(credentials)) { return } + // Inject quota_limit for apikey accounts + let finalExtra = extra + if (type === 'apikey' && editQuotaLimit.value != null && editQuotaLimit.value > 0) { + finalExtra = { ...(extra || {}), quota_limit: editQuotaLimit.value } + } await doCreateAccount({ name: form.name, notes: form.notes, platform, type, credentials, - extra, + extra: finalExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3596,6 +3620,7 @@ const handleOpenAIExchange = async (authCode: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3625,6 +3650,7 @@ const handleOpenAIExchange = async (authCode: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3702,6 +3728,7 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3729,6 +3756,7 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3817,6 +3845,7 @@ const handleSoraValidateST = async (sessionTokenInput: string) => { extra: soraExtra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -3905,6 +3934,7 @@ const handleAntigravityValidateRT = async (refreshTokenInput: string) => { extra: {}, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, @@ -4063,8 +4093,11 @@ const handleAnthropicExchange = async (authCode: string) => { } // Add RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - extra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + extra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? baseRpm.value + : DEFAULT_BASE_RPM extra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { extra.rpm_sticky_buffer = rpmStickyBuffer.value @@ -4175,8 +4208,11 @@ const handleCookieAuth = async (sessionKey: string) => { } // Add RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - extra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + extra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? baseRpm.value + : DEFAULT_BASE_RPM extra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { extra.rpm_sticky_buffer = rpmStickyBuffer.value @@ -4222,6 +4258,7 @@ const handleCookieAuth = async (sessionKey: string) => { extra, proxy_id: form.proxy_id, concurrency: form.concurrency, + load_factor: form.load_factor ?? undefined, priority: form.priority, rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 20d785e2..be7d2d45 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -650,10 +650,18 @@
-
+
- + +
+
+ + +

{{ t('admin.accounts.loadFactorHint') }}

@@ -759,6 +767,9 @@
+ + +
(OPENAI_WS_MODE_OFF const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) +const editQuotaLimit = ref(null) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 + // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ @@ -1464,6 +1478,7 @@ const form = reactive({ notes: '', proxy_id: null as number | null, concurrency: 1, + load_factor: null as number | null, priority: 1, rate_multiplier: 1, status: 'active' as 'active' | 'inactive', @@ -1497,9 +1512,12 @@ watch( form.notes = newAccount.notes || '' form.proxy_id = newAccount.proxy_id form.concurrency = newAccount.concurrency + form.load_factor = newAccount.load_factor ?? null form.priority = newAccount.priority form.rate_multiplier = newAccount.rate_multiplier ?? 1 - form.status = newAccount.status as 'active' | 'inactive' + form.status = (newAccount.status === 'active' || newAccount.status === 'inactive') + ? newAccount.status + : 'active' form.group_ids = newAccount.group_ids || [] form.expires_at = newAccount.expires_at ?? null @@ -1540,6 +1558,14 @@ watch( anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true } + // Load quota limit for apikey accounts + if (newAccount.type === 'apikey') { + const quotaVal = extra?.quota_limit as number | undefined + editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null + } else { + editQuotaLimit.value = null + } + // Load antigravity model mapping (Antigravity 只支持映射模式) if (newAccount.platform === 'antigravity') { const credentials = newAccount.credentials as Record | undefined @@ -2039,6 +2065,11 @@ const handleSubmit = async () => { if (!props.account) return const accountID = props.account.id + if (form.status !== 'active' && form.status !== 'inactive') { + appStore.showError(t('admin.accounts.pleaseSelectStatus')) + return + } + const updatePayload: Record = { ...form } try { // 后端期望 proxy_id: 0 表示清除代理,而不是 null @@ -2048,6 +2079,11 @@ const handleSubmit = async () => { if (form.expires_at === null) { updatePayload.expires_at = 0 } + // load_factor: 空值/NaN/0/负数 时发送 0(后端约定 <= 0 = 清除) + const lf = form.load_factor + if (lf == null || Number.isNaN(lf) || lf <= 0) { + updatePayload.load_factor = 0 + } updatePayload.auto_pause_on_expired = autoPauseOnExpired.value // For apikey type, handle credentials update @@ -2187,8 +2223,11 @@ const handleSubmit = async () => { } // RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - newExtra.base_rpm = baseRpm.value + if (rpmLimitEnabled.value) { + const DEFAULT_BASE_RPM = 15 + newExtra.base_rpm = (baseRpm.value != null && baseRpm.value > 0) + ? baseRpm.value + : DEFAULT_BASE_RPM newExtra.rpm_strategy = rpmStrategy.value if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { newExtra.rpm_sticky_buffer = rpmStickyBuffer.value @@ -2282,6 +2321,19 @@ const handleSubmit = async () => { updatePayload.extra = newExtra } + // For apikey accounts, handle quota_limit in extra + if (props.account.type === 'apikey') { + const currentExtra = (updatePayload.extra as Record) || + (props.account.extra as Record) || {} + const newExtra: Record = { ...currentExtra } + if (editQuotaLimit.value != null && editQuotaLimit.value > 0) { + newExtra.quota_limit = editQuotaLimit.value + } else { + delete newExtra.quota_limit + } + updatePayload.extra = newExtra + } + const canContinue = await ensureAntigravityMixedChannelConfirmed(async () => { await submitUpdateAccount(accountID, updatePayload) }) diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue new file mode 100644 index 00000000..1be73a25 --- /dev/null +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -0,0 +1,92 @@ + + + diff --git a/frontend/src/components/admin/account/AccountActionMenu.vue b/frontend/src/components/admin/account/AccountActionMenu.vue index fbff0bed..02596b9f 100644 --- a/frontend/src/components/admin/account/AccountActionMenu.vue +++ b/frontend/src/components/admin/account/AccountActionMenu.vue @@ -41,6 +41,10 @@ {{ t('admin.accounts.clearRateLimit') }} +
@@ -55,7 +59,7 @@ import { Icon } from '@/components/icons' import type { Account } from '@/types' const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>() -const emit = defineEmits(['close', 'test', 'stats', 'schedule', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit']) +const emit = defineEmits(['close', 'test', 'stats', 'schedule', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit', 'reset-quota']) const { t } = useI18n() const isRateLimited = computed(() => { if (props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) { @@ -71,6 +75,12 @@ const isRateLimited = computed(() => { return false }) const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date()) +const hasQuotaLimit = computed(() => { + return props.account?.type === 'apikey' && + props.account?.quota_limit !== undefined && + props.account?.quota_limit !== null && + props.account?.quota_limit > 0 +}) const handleKeydown = (event: KeyboardEvent) => { if (event.key === 'Escape') emit('close') diff --git a/frontend/src/components/charts/TokenUsageTrend.vue b/frontend/src/components/charts/TokenUsageTrend.vue index d9ceda87..a255fb03 100644 --- a/frontend/src/components/charts/TokenUsageTrend.vue +++ b/frontend/src/components/charts/TokenUsageTrend.vue @@ -63,7 +63,8 @@ const chartColors = computed(() => ({ grid: isDarkMode.value ? '#374151' : '#e5e7eb', input: '#3b82f6', output: '#10b981', - cache: '#f59e0b' + cacheCreation: '#f59e0b', + cacheRead: '#06b6d4' })) const chartData = computed(() => { @@ -89,10 +90,18 @@ const chartData = computed(() => { tension: 0.3 }, { - label: 'Cache', - data: props.trendData.map((d) => d.cache_tokens), - borderColor: chartColors.value.cache, - backgroundColor: `${chartColors.value.cache}20`, + label: 'Cache Creation', + data: props.trendData.map((d) => d.cache_creation_tokens), + borderColor: chartColors.value.cacheCreation, + backgroundColor: `${chartColors.value.cacheCreation}20`, + fill: true, + tension: 0.3 + }, + { + label: 'Cache Read', + data: props.trendData.map((d) => d.cache_read_tokens), + borderColor: chartColors.value.cacheRead, + backgroundColor: `${chartColors.value.cacheRead}20`, fill: true, tension: 0.3 } diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index 4dd7ff0c..99d78f69 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -443,7 +443,22 @@ $env:ANTHROPIC_AUTH_TOKEN="${apiKey}"` content = '' } - return [{ path, content }] + const vscodeSettingsPath = activeTab.value === 'unix' + ? '~/.claude/settings.json' + : '%userprofile%\\.claude\\settings.json' + + const vscodeContent = `{ + "env": { + "ANTHROPIC_BASE_URL": "${baseUrl}", + "ANTHROPIC_AUTH_TOKEN": "${apiKey}", + "CLAUDE_CODE_ATTRIBUTION_HEADER": "0" + } +}` + + return [ + { path, content }, + { path: vscodeSettingsPath, content: vscodeContent, hint: 'VSCode Claude Code' } + ] } function generateGeminiCliContent(baseUrl: string, apiKey: string): FileConfig { @@ -496,16 +511,18 @@ function generateOpenAIFiles(baseUrl: string, apiKey: string): FileConfig[] { const configDir = isWindows ? '%userprofile%\\.codex' : '~/.codex' // config.toml content - const configContent = `model_provider = "sub2api" -model = "gpt-5.3-codex" -model_reasoning_effort = "high" -network_access = "enabled" + const configContent = `model_provider = "OpenAI" +model = "gpt-5.4" +review_model = "gpt-5.4" +model_reasoning_effort = "xhigh" disable_response_storage = true +network_access = "enabled" windows_wsl_setup_acknowledged = true -model_verbosity = "high" +model_context_window = 1000000 +model_auto_compact_token_limit = 900000 -[model_providers.sub2api] -name = "sub2api" +[model_providers.OpenAI] +name = "OpenAI" base_url = "${baseUrl}" wire_api = "responses" requires_openai_auth = true` @@ -533,16 +550,18 @@ function generateOpenAIWsFiles(baseUrl: string, apiKey: string): FileConfig[] { const configDir = isWindows ? '%userprofile%\\.codex' : '~/.codex' // config.toml content with WebSocket v2 - const configContent = `model_provider = "sub2api" -model = "gpt-5.3-codex" -model_reasoning_effort = "high" -network_access = "enabled" + const configContent = `model_provider = "OpenAI" +model = "gpt-5.4" +review_model = "gpt-5.4" +model_reasoning_effort = "xhigh" disable_response_storage = true +network_access = "enabled" windows_wsl_setup_acknowledged = true -model_verbosity = "high" +model_context_window = 1000000 +model_auto_compact_token_limit = 900000 -[model_providers.sub2api] -name = "sub2api" +[model_providers.OpenAI] +name = "OpenAI" base_url = "${baseUrl}" wire_api = "responses" supports_websockets = true @@ -655,6 +674,22 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin xhigh: {} } }, + 'gpt-5.4': { + name: 'GPT-5.4', + limit: { + context: 1050000, + output: 128000 + }, + options: { + store: false + }, + variants: { + low: {}, + medium: {}, + high: {}, + xhigh: {} + } + }, 'gpt-5.3-codex-spark': { name: 'GPT-5.3 Codex Spark', limit: { diff --git a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts index 4088e5a4..79c88a29 100644 --- a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts +++ b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts @@ -2,6 +2,13 @@ import { describe, expect, it } from 'vitest' import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist' describe('useModelWhitelist', () => { + it('openai 模型列表包含 GPT-5.4 官方快照', () => { + const models = getModelsByPlatform('openai') + + expect(models).toContain('gpt-5.4') + expect(models).toContain('gpt-5.4-2026-03-05') + }) + it('antigravity 模型列表包含图片模型兼容项', () => { const models = getModelsByPlatform('antigravity') @@ -15,4 +22,12 @@ describe('useModelWhitelist', () => { 'gemini-3.1-flash-image': 'gemini-3.1-flash-image' }) }) + + it('whitelist 模式会保留 GPT-5.4 官方快照的精确映射', () => { + const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-2026-03-05'], []) + + expect(mapping).toEqual({ + 'gpt-5.4-2026-03-05': 'gpt-5.4-2026-03-05' + }) + }) }) diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 444e4b91..986a99f4 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -24,6 +24,8 @@ const openaiModels = [ // GPT-5.2 系列 'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest', 'gpt-5.2-codex', 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11', + // GPT-5.4 系列 + 'gpt-5.4', 'gpt-5.4-2026-03-05', // GPT-5.3 系列 'gpt-5.3-codex', 'gpt-5.3-codex-spark', 'chatgpt-4o-latest', @@ -277,6 +279,7 @@ const openaiPresetMappings = [ { label: 'GPT-5.3 Codex Spark', from: 'gpt-5.3-codex-spark', to: 'gpt-5.3-codex-spark', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' }, { label: 'GPT-5.1', from: 'gpt-5.1', to: 'gpt-5.1', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' }, { label: 'GPT-5.2', from: 'gpt-5.2', to: 'gpt-5.2', color: 'bg-red-100 text-red-700 hover:bg-red-200 dark:bg-red-900/30 dark:text-red-400' }, + { label: 'GPT-5.4', from: 'gpt-5.4', to: 'gpt-5.4', color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400' }, { label: 'GPT-5.1 Codex', from: 'gpt-5.1-codex', to: 'gpt-5.1-codex', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' } ] diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index cb388600..87d8d816 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -133,6 +133,8 @@ export default { requests: 'Requests', inputTokens: 'Input Tokens', outputTokens: 'Output Tokens', + cacheCreationTokens: 'Cache Creation', + cacheReadTokens: 'Cache Read', totalTokens: 'Total Tokens', cost: 'Cost', // Status @@ -155,11 +157,19 @@ export default { subscriptionExpires: 'Subscription Expires', // Usage stat cells todayRequests: 'Today Requests', + todayInputTokens: 'Today Input', + todayOutputTokens: 'Today Output', todayTokens: 'Today Tokens', + todayCacheCreation: 'Today Cache Creation', + todayCacheRead: 'Today Cache Read', todayCost: 'Today Cost', rpmTpm: 'RPM / TPM', totalRequests: 'Total Requests', + totalInputTokens: 'Total Input', + totalOutputTokens: 'Total Output', totalTokensLabel: 'Total Tokens', + totalCacheCreation: 'Total Cache Creation', + totalCacheRead: 'Total Cache Read', totalCost: 'Total Cost', avgDuration: 'Avg Duration', // Messages @@ -1724,6 +1734,10 @@ export default { stickyExemptWarning: 'RPM limit (Sticky Exempt) - Approaching limit', stickyExemptOver: 'RPM limit (Sticky Exempt) - Over limit, sticky only' }, + quota: { + exceeded: 'Quota exceeded, account paused', + normal: 'Quota normal' + }, }, tempUnschedulable: { title: 'Temp Unschedulable', @@ -1769,6 +1783,14 @@ export default { } }, clearRateLimit: 'Clear Rate Limit', + resetQuota: 'Reset Quota', + quotaLimit: 'Quota Limit', + quotaLimitPlaceholder: '0 means unlimited', + quotaLimitHint: 'Set max spending limit (USD). Account will be paused when reached. Changing limit won\'t reset usage.', + quotaLimitToggle: 'Enable Quota Limit', + quotaLimitToggleHint: 'When enabled, account will be paused when usage reaches the set limit', + quotaLimitAmount: 'Limit Amount', + quotaLimitAmountHint: 'Maximum spending limit (USD). Account will be auto-paused when reached. Changing limit won\'t reset usage.', testConnection: 'Test Connection', reAuthorize: 'Re-Authorize', refreshToken: 'Refresh Token', @@ -1981,10 +2003,12 @@ export default { proxy: 'Proxy', noProxy: 'No Proxy', concurrency: 'Concurrency', + loadFactor: 'Load Factor', + loadFactorHint: 'Higher load factor increases scheduling frequency', priority: 'Priority', priorityHint: 'Lower value accounts are used first', billingRateMultiplier: 'Billing Rate Multiplier', - billingRateMultiplierHint: '>=0, 0 means free. Affects account billing only', + billingRateMultiplierHint: '0 = free, affects account billing only', expiresAt: 'Expires At', expiresAtHint: 'Leave empty for no expiration', higherPriorityFirst: 'Lower value means higher priority', @@ -2000,6 +2024,7 @@ export default { accountUpdated: 'Account updated successfully', failedToCreate: 'Failed to create account', failedToUpdate: 'Failed to update account', + pleaseSelectStatus: 'Please select a valid account status', mixedChannelWarningTitle: 'Mixed Channel Warning', mixedChannelWarning: 'Warning: Group "{groupName}" contains both {currentPlatform} and {otherPlatform} accounts. Mixing different channels may cause thinking block signature validation issues, which will fallback to non-thinking mode. Are you sure you want to continue?', pleaseEnterAccountName: 'Please enter account name', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 7c208aa9..ec783615 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -133,6 +133,8 @@ export default { requests: '请求数', inputTokens: '输入 Tokens', outputTokens: '输出 Tokens', + cacheCreationTokens: '缓存创建', + cacheReadTokens: '缓存读取', totalTokens: '总 Tokens', cost: '费用', // Status @@ -155,11 +157,19 @@ export default { subscriptionExpires: '订阅到期', // Usage stat cells todayRequests: '今日请求', + todayInputTokens: '今日输入', + todayOutputTokens: '今日输出', todayTokens: '今日 Tokens', + todayCacheCreation: '今日缓存创建', + todayCacheRead: '今日缓存读取', todayCost: '今日费用', rpmTpm: 'RPM / TPM', totalRequests: '累计请求', + totalInputTokens: '累计输入', + totalOutputTokens: '累计输出', totalTokensLabel: '累计 Tokens', + totalCacheCreation: '累计缓存创建', + totalCacheRead: '累计缓存读取', totalCost: '累计费用', avgDuration: '平均耗时', // Messages @@ -1774,8 +1784,20 @@ export default { stickyExemptWarning: 'RPM 限制 (粘性豁免) - 接近阈值', stickyExemptOver: 'RPM 限制 (粘性豁免) - 超限,仅粘性会话' }, + quota: { + exceeded: '配额已用完,账号暂停调度', + normal: '配额正常' + }, }, clearRateLimit: '清除速率限制', + resetQuota: '重置配额', + quotaLimit: '配额限制', + quotaLimitPlaceholder: '0 表示不限制', + quotaLimitHint: '设置最大使用额度(美元),达到后账号暂停调度。修改限额不会重置已用额度。', + quotaLimitToggle: '启用配额限制', + quotaLimitToggleHint: '开启后,当账号用量达到设定额度时自动暂停调度', + quotaLimitAmount: '限额金额', + quotaLimitAmountHint: '账号最大可用额度(美元),达到后自动暂停。修改限额不会重置已用额度。', testConnection: '测试连接', reAuthorize: '重新授权', refreshToken: '刷新令牌', @@ -2123,10 +2145,12 @@ export default { proxy: '代理', noProxy: '无代理', concurrency: '并发数', + loadFactor: '负载因子', + loadFactorHint: '提高负载因子可以提高对账号的调度频率', priority: '优先级', priorityHint: '优先级越小的账号优先使用', billingRateMultiplier: '账号计费倍率', - billingRateMultiplierHint: '>=0,0 表示该账号计费为 0;仅影响账号计费口径', + billingRateMultiplierHint: '0 表示不计费,仅影响账号计费', expiresAt: '过期时间', expiresAtHint: '留空表示不过期', higherPriorityFirst: '数值越小优先级越高', @@ -2142,6 +2166,7 @@ export default { accountUpdated: '账号更新成功', failedToCreate: '创建账号失败', failedToUpdate: '更新账号失败', + pleaseSelectStatus: '请选择有效的账号状态', mixedChannelWarningTitle: '混合渠道警告', mixedChannelWarning: '警告:分组 "{groupName}" 中同时包含 {currentPlatform} 和 {otherPlatform} 账号。混合使用不同渠道可能导致 thinking block 签名验证问题,会自动回退到非 thinking 模式。确定要继续吗?', pleaseEnterAccountName: '请输入账号名称', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index cdc4953a..243586bf 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -653,6 +653,7 @@ export interface Account { } & Record) proxy_id: number | null concurrency: number + load_factor?: number | null current_concurrency?: number // Real-time concurrency count from Redis priority: number rate_multiplier?: number // Account billing multiplier (>=0, 0 means free) @@ -705,6 +706,10 @@ export interface Account { cache_ttl_override_enabled?: boolean | null cache_ttl_override_target?: string | null + // API Key 账号配额限制 + quota_limit?: number | null + quota_used?: number | null + // 运行时状态(仅当启用对应限制时返回) current_window_cost?: number | null // 当前窗口费用 active_sessions?: number | null // 当前活跃会话数 @@ -783,6 +788,7 @@ export interface CreateAccountRequest { extra?: Record proxy_id?: number | null concurrency?: number + load_factor?: number | null priority?: number rate_multiplier?: number // Account billing multiplier (>=0, 0 means free) group_ids?: number[] @@ -799,6 +805,7 @@ export interface UpdateAccountRequest { extra?: Record proxy_id?: number | null concurrency?: number + load_factor?: number | null priority?: number rate_multiplier?: number // Account billing multiplier (>=0, 0 means free) schedulable?: boolean @@ -1098,7 +1105,8 @@ export interface TrendDataPoint { requests: number input_tokens: number output_tokens: number - cache_tokens: number + cache_creation_tokens: number + cache_read_tokens: number total_tokens: number cost: number // 标准计费 actual_cost: number // 实际扣除 @@ -1109,6 +1117,8 @@ export interface ModelStat { requests: number input_tokens: number output_tokens: number + cache_creation_tokens: number + cache_read_tokens: number total_tokens: number cost: number // 标准计费 actual_cost: number // 实际扣除 diff --git a/frontend/src/views/KeyUsageView.vue b/frontend/src/views/KeyUsageView.vue index a061ad9f..755f1966 100644 --- a/frontend/src/views/KeyUsageView.vue +++ b/frontend/src/views/KeyUsageView.vue @@ -302,6 +302,8 @@ {{ t('keyUsage.requests') }} {{ t('keyUsage.inputTokens') }} {{ t('keyUsage.outputTokens') }} + {{ t('keyUsage.cacheCreationTokens') }} + {{ t('keyUsage.cacheReadTokens') }} {{ t('keyUsage.totalTokens') }} {{ t('keyUsage.cost') }} @@ -316,6 +318,8 @@ {{ fmtNum(m.requests) }} {{ fmtNum(m.input_tokens) }} {{ fmtNum(m.output_tokens) }} + {{ fmtNum(m.cache_creation_tokens) }} + {{ fmtNum(m.cache_read_tokens) }} {{ fmtNum(m.total_tokens) }} {{ usd(m.actual_cost != null ? m.actual_cost : m.cost) }} @@ -694,11 +698,19 @@ const usageStatCells = computed(() => { return [ { label: t('keyUsage.todayRequests'), value: fmtNum(today.requests) }, + { label: t('keyUsage.todayInputTokens'), value: fmtNum(today.input_tokens) }, + { label: t('keyUsage.todayOutputTokens'), value: fmtNum(today.output_tokens) }, { label: t('keyUsage.todayTokens'), value: fmtNum(today.total_tokens) }, + { label: t('keyUsage.todayCacheCreation'), value: fmtNum(today.cache_creation_tokens) }, + { label: t('keyUsage.todayCacheRead'), value: fmtNum(today.cache_read_tokens) }, { label: t('keyUsage.todayCost'), value: usd(today.actual_cost) }, { label: t('keyUsage.rpmTpm'), value: `${usage.rpm || 0} / ${usage.tpm || 0}` }, { label: t('keyUsage.totalRequests'), value: fmtNum(total.requests) }, + { label: t('keyUsage.totalInputTokens'), value: fmtNum(total.input_tokens) }, + { label: t('keyUsage.totalOutputTokens'), value: fmtNum(total.output_tokens) }, { label: t('keyUsage.totalTokensLabel'), value: fmtNum(total.total_tokens) }, + { label: t('keyUsage.totalCacheCreation'), value: fmtNum(total.cache_creation_tokens) }, + { label: t('keyUsage.totalCacheRead'), value: fmtNum(total.cache_read_tokens) }, { label: t('keyUsage.totalCost'), value: usd(total.actual_cost) }, { label: t('keyUsage.avgDuration'), value: usage.average_duration_ms ? `${Math.round(usage.average_duration_ms)} ms` : '-' }, ] diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 146b2647..0173ea0a 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -261,7 +261,7 @@ - + @@ -1125,6 +1125,16 @@ const handleClearRateLimit = async (a: Account) => { console.error('Failed to clear rate limit:', error) } } +const handleResetQuota = async (a: Account) => { + try { + const updated = await adminAPI.accounts.resetAccountQuota(a.id) + patchAccountInList(updated) + enterAutoRefreshSilentWindow() + appStore.showSuccess(t('common.success')) + } catch (error) { + console.error('Failed to reset quota:', error) + } +} const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true } const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } } const handleToggleSchedulable = async (a: Account) => { diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue index ff875325..4bd5f6d8 100644 --- a/frontend/src/views/user/UsageView.vue +++ b/frontend/src/views/user/UsageView.vue @@ -113,6 +113,9 @@
+