Merge branch 'main' into release/custom-0.1.91

# Conflicts:
#	frontend/src/components/admin/account/AccountActionMenu.vue
#	frontend/src/views/admin/AccountsView.vue
This commit is contained in:
erio
2026-03-06 04:08:14 +08:00
97 changed files with 6442 additions and 311 deletions

View File

@@ -1 +1 @@
0.1.88
0.1.90.9

View File

@@ -41,6 +41,8 @@ type Account struct {
ProxyID *int64 `json:"proxy_id,omitempty"`
// Concurrency holds the value of the "concurrency" field.
Concurrency int `json:"concurrency,omitempty"`
// LoadFactor holds the value of the "load_factor" field.
LoadFactor *int `json:"load_factor,omitempty"`
// Priority holds the value of the "priority" field.
Priority int `json:"priority,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
@@ -143,7 +145,7 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case account.FieldRateMultiplier:
values[i] = new(sql.NullFloat64)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldLoadFactor, account.FieldPriority:
values[i] = new(sql.NullInt64)
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus:
values[i] = new(sql.NullString)
@@ -243,6 +245,13 @@ func (_m *Account) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Concurrency = int(value.Int64)
}
case account.FieldLoadFactor:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field load_factor", values[i])
} else if value.Valid {
_m.LoadFactor = new(int)
*_m.LoadFactor = int(value.Int64)
}
case account.FieldPriority:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field priority", values[i])
@@ -445,6 +454,11 @@ func (_m *Account) String() string {
builder.WriteString("concurrency=")
builder.WriteString(fmt.Sprintf("%v", _m.Concurrency))
builder.WriteString(", ")
if v := _m.LoadFactor; v != nil {
builder.WriteString("load_factor=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("priority=")
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
builder.WriteString(", ")

View File

@@ -37,6 +37,8 @@ const (
FieldProxyID = "proxy_id"
// FieldConcurrency holds the string denoting the concurrency field in the database.
FieldConcurrency = "concurrency"
// FieldLoadFactor holds the string denoting the load_factor field in the database.
FieldLoadFactor = "load_factor"
// FieldPriority holds the string denoting the priority field in the database.
FieldPriority = "priority"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
@@ -121,6 +123,7 @@ var Columns = []string{
FieldExtra,
FieldProxyID,
FieldConcurrency,
FieldLoadFactor,
FieldPriority,
FieldRateMultiplier,
FieldStatus,
@@ -250,6 +253,11 @@ func ByConcurrency(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldConcurrency, opts...).ToFunc()
}
// ByLoadFactor returns an OrderOption that orders results by the load_factor field.
// Any opts are forwarded unchanged to sql.OrderByField.
func ByLoadFactor(opts ...sql.OrderTermOption) OrderOption {
	term := sql.OrderByField(FieldLoadFactor, opts...)
	return term.ToFunc()
}
// ByPriority orders the results by the priority field.
func ByPriority(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority, opts...).ToFunc()

View File

@@ -100,6 +100,11 @@ func Concurrency(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldConcurrency, v))
}
// LoadFactor builds an equality predicate on the "load_factor" field.
// It behaves exactly like LoadFactorEQ.
func LoadFactor(v int) predicate.Account {
	equals := sql.FieldEQ(FieldLoadFactor, v)
	return predicate.Account(equals)
}
// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ.
func Priority(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v))
@@ -650,6 +655,56 @@ func ConcurrencyLTE(v int) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldConcurrency, v))
}
// LoadFactorEQ builds the EQ predicate on the "load_factor" field for value v.
func LoadFactorEQ(v int) predicate.Account {
	equals := sql.FieldEQ(FieldLoadFactor, v)
	return predicate.Account(equals)
}
// LoadFactorNEQ builds the NEQ predicate on the "load_factor" field for value v.
func LoadFactorNEQ(v int) predicate.Account {
	differs := sql.FieldNEQ(FieldLoadFactor, v)
	return predicate.Account(differs)
}
// LoadFactorIn builds the In predicate on the "load_factor" field for the values vs.
func LoadFactorIn(vs ...int) predicate.Account {
	member := sql.FieldIn(FieldLoadFactor, vs...)
	return predicate.Account(member)
}
// LoadFactorNotIn builds the NotIn predicate on the "load_factor" field for the values vs.
func LoadFactorNotIn(vs ...int) predicate.Account {
	excluded := sql.FieldNotIn(FieldLoadFactor, vs...)
	return predicate.Account(excluded)
}
// LoadFactorGT builds the GT predicate on the "load_factor" field for value v.
func LoadFactorGT(v int) predicate.Account {
	greater := sql.FieldGT(FieldLoadFactor, v)
	return predicate.Account(greater)
}
// LoadFactorGTE builds the GTE predicate on the "load_factor" field for value v.
func LoadFactorGTE(v int) predicate.Account {
	atLeast := sql.FieldGTE(FieldLoadFactor, v)
	return predicate.Account(atLeast)
}
// LoadFactorLT builds the LT predicate on the "load_factor" field for value v.
func LoadFactorLT(v int) predicate.Account {
	less := sql.FieldLT(FieldLoadFactor, v)
	return predicate.Account(less)
}
// LoadFactorLTE builds the LTE predicate on the "load_factor" field for value v.
func LoadFactorLTE(v int) predicate.Account {
	atMost := sql.FieldLTE(FieldLoadFactor, v)
	return predicate.Account(atMost)
}
// LoadFactorIsNil builds the IsNil predicate on the "load_factor" field.
func LoadFactorIsNil() predicate.Account {
	isNull := sql.FieldIsNull(FieldLoadFactor)
	return predicate.Account(isNull)
}
// LoadFactorNotNil builds the NotNil predicate on the "load_factor" field.
func LoadFactorNotNil() predicate.Account {
	notNull := sql.FieldNotNull(FieldLoadFactor)
	return predicate.Account(notNull)
}
// PriorityEQ applies the EQ predicate on the "priority" field.
func PriorityEQ(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v))

View File

@@ -139,6 +139,20 @@ func (_c *AccountCreate) SetNillableConcurrency(v *int) *AccountCreate {
return _c
}
// SetLoadFactor sets the "load_factor" field.
// The value is recorded on the builder's mutation and the builder itself is
// returned so calls can be chained.
func (_c *AccountCreate) SetLoadFactor(v int) *AccountCreate {
_c.mutation.SetLoadFactor(v)
return _c
}
// SetNillableLoadFactor sets the "load_factor" field when v is non-nil.
// A nil v leaves the builder untouched; either way the builder is returned.
func (_c *AccountCreate) SetNillableLoadFactor(v *int) *AccountCreate {
	if v == nil {
		return _c
	}
	return _c.SetLoadFactor(*v)
}
// SetPriority sets the "priority" field.
func (_c *AccountCreate) SetPriority(v int) *AccountCreate {
_c.mutation.SetPriority(v)
@@ -623,6 +637,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldConcurrency, field.TypeInt, value)
_node.Concurrency = value
}
if value, ok := _c.mutation.LoadFactor(); ok {
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
_node.LoadFactor = &value
}
if value, ok := _c.mutation.Priority(); ok {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
_node.Priority = value
@@ -936,6 +954,30 @@ func (u *AccountUpsert) AddConcurrency(v int) *AccountUpsert {
return u
}
// SetLoadFactor sets the "load_factor" field to v and returns the
// receiver for chaining.
func (u *AccountUpsert) SetLoadFactor(v int) *AccountUpsert {
	const column = account.FieldLoadFactor
	u.Set(column, v)
	return u
}
// UpdateLoadFactor sets the "load_factor" field to the value that was provided
// on create (via SetExcluded) and returns the receiver for chaining.
func (u *AccountUpsert) UpdateLoadFactor() *AccountUpsert {
	const column = account.FieldLoadFactor
	u.SetExcluded(column)
	return u
}
// AddLoadFactor adds v to the "load_factor" field and returns the receiver
// for chaining.
func (u *AccountUpsert) AddLoadFactor(v int) *AccountUpsert {
	const column = account.FieldLoadFactor
	u.Add(column, v)
	return u
}
// ClearLoadFactor clears the value of the "load_factor" field (via SetNull)
// and returns the receiver for chaining.
func (u *AccountUpsert) ClearLoadFactor() *AccountUpsert {
	const column = account.FieldLoadFactor
	u.SetNull(column)
	return u
}
// SetPriority sets the "priority" field.
func (u *AccountUpsert) SetPriority(v int) *AccountUpsert {
u.Set(account.FieldPriority, v)
@@ -1419,6 +1461,34 @@ func (u *AccountUpsertOne) UpdateConcurrency() *AccountUpsertOne {
})
}
// SetLoadFactor sets the "load_factor" field by registering an update that
// calls SetLoadFactor(v) on the underlying AccountUpsert.
func (u *AccountUpsertOne) SetLoadFactor(v int) *AccountUpsertOne {
	apply := func(upsert *AccountUpsert) {
		upsert.SetLoadFactor(v)
	}
	return u.Update(apply)
}
// AddLoadFactor adds v to the "load_factor" field by registering an update
// that calls AddLoadFactor(v) on the underlying AccountUpsert.
func (u *AccountUpsertOne) AddLoadFactor(v int) *AccountUpsertOne {
	apply := func(upsert *AccountUpsert) {
		upsert.AddLoadFactor(v)
	}
	return u.Update(apply)
}
// UpdateLoadFactor sets the "load_factor" field to the value that was provided
// on create, by registering an update that calls UpdateLoadFactor on the
// underlying AccountUpsert.
func (u *AccountUpsertOne) UpdateLoadFactor() *AccountUpsertOne {
	apply := func(upsert *AccountUpsert) {
		upsert.UpdateLoadFactor()
	}
	return u.Update(apply)
}
// ClearLoadFactor clears the value of the "load_factor" field by registering
// an update that calls ClearLoadFactor on the underlying AccountUpsert.
func (u *AccountUpsertOne) ClearLoadFactor() *AccountUpsertOne {
	apply := func(upsert *AccountUpsert) {
		upsert.ClearLoadFactor()
	}
	return u.Update(apply)
}
// SetPriority sets the "priority" field.
func (u *AccountUpsertOne) SetPriority(v int) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -2113,6 +2183,34 @@ func (u *AccountUpsertBulk) UpdateConcurrency() *AccountUpsertBulk {
})
}
// SetLoadFactor sets the "load_factor" field by registering an update that
// calls SetLoadFactor(v) on the underlying AccountUpsert.
func (u *AccountUpsertBulk) SetLoadFactor(v int) *AccountUpsertBulk {
	apply := func(upsert *AccountUpsert) {
		upsert.SetLoadFactor(v)
	}
	return u.Update(apply)
}
// AddLoadFactor adds v to the "load_factor" field by registering an update
// that calls AddLoadFactor(v) on the underlying AccountUpsert.
func (u *AccountUpsertBulk) AddLoadFactor(v int) *AccountUpsertBulk {
	apply := func(upsert *AccountUpsert) {
		upsert.AddLoadFactor(v)
	}
	return u.Update(apply)
}
// UpdateLoadFactor sets the "load_factor" field to the value that was provided
// on create, by registering an update that calls UpdateLoadFactor on the
// underlying AccountUpsert.
func (u *AccountUpsertBulk) UpdateLoadFactor() *AccountUpsertBulk {
	apply := func(upsert *AccountUpsert) {
		upsert.UpdateLoadFactor()
	}
	return u.Update(apply)
}
// ClearLoadFactor clears the value of the "load_factor" field by registering
// an update that calls ClearLoadFactor on the underlying AccountUpsert.
func (u *AccountUpsertBulk) ClearLoadFactor() *AccountUpsertBulk {
	apply := func(upsert *AccountUpsert) {
		upsert.ClearLoadFactor()
	}
	return u.Update(apply)
}
// SetPriority sets the "priority" field.
func (u *AccountUpsertBulk) SetPriority(v int) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {

View File

@@ -172,6 +172,33 @@ func (_u *AccountUpdate) AddConcurrency(v int) *AccountUpdate {
return _u
}
// SetLoadFactor sets the "load_factor" field.
// ResetLoadFactor is called on the mutation first, then v is recorded as the
// new value. The builder is returned so calls can be chained.
func (_u *AccountUpdate) SetLoadFactor(v int) *AccountUpdate {
_u.mutation.ResetLoadFactor()
_u.mutation.SetLoadFactor(v)
return _u
}
// SetNillableLoadFactor sets the "load_factor" field when v is non-nil.
// A nil v leaves the builder untouched; either way the builder is returned.
func (_u *AccountUpdate) SetNillableLoadFactor(v *int) *AccountUpdate {
	if v == nil {
		return _u
	}
	return _u.SetLoadFactor(*v)
}
// AddLoadFactor adds value to the "load_factor" field.
// The increment v is recorded on the builder's mutation; the builder is
// returned so calls can be chained.
func (_u *AccountUpdate) AddLoadFactor(v int) *AccountUpdate {
_u.mutation.AddLoadFactor(v)
return _u
}
// ClearLoadFactor clears the value of the "load_factor" field.
// The clear is recorded on the builder's mutation; the builder is returned so
// calls can be chained.
func (_u *AccountUpdate) ClearLoadFactor() *AccountUpdate {
_u.mutation.ClearLoadFactor()
return _u
}
// SetPriority sets the "priority" field.
func (_u *AccountUpdate) SetPriority(v int) *AccountUpdate {
_u.mutation.ResetPriority()
@@ -684,6 +711,15 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedConcurrency(); ok {
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
}
if value, ok := _u.mutation.LoadFactor(); ok {
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedLoadFactor(); ok {
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
}
if _u.mutation.LoadFactorCleared() {
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
}
if value, ok := _u.mutation.Priority(); ok {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
}
@@ -1063,6 +1099,33 @@ func (_u *AccountUpdateOne) AddConcurrency(v int) *AccountUpdateOne {
return _u
}
// SetLoadFactor sets the "load_factor" field.
// ResetLoadFactor is called on the mutation first, then v is recorded as the
// new value. The builder is returned so calls can be chained.
func (_u *AccountUpdateOne) SetLoadFactor(v int) *AccountUpdateOne {
_u.mutation.ResetLoadFactor()
_u.mutation.SetLoadFactor(v)
return _u
}
// SetNillableLoadFactor sets the "load_factor" field when v is non-nil.
// A nil v leaves the builder untouched; either way the builder is returned.
func (_u *AccountUpdateOne) SetNillableLoadFactor(v *int) *AccountUpdateOne {
	if v == nil {
		return _u
	}
	return _u.SetLoadFactor(*v)
}
// AddLoadFactor adds value to the "load_factor" field.
// The increment v is recorded on the builder's mutation; the builder is
// returned so calls can be chained.
func (_u *AccountUpdateOne) AddLoadFactor(v int) *AccountUpdateOne {
_u.mutation.AddLoadFactor(v)
return _u
}
// ClearLoadFactor clears the value of the "load_factor" field.
// The clear is recorded on the builder's mutation; the builder is returned so
// calls can be chained.
func (_u *AccountUpdateOne) ClearLoadFactor() *AccountUpdateOne {
_u.mutation.ClearLoadFactor()
return _u
}
// SetPriority sets the "priority" field.
func (_u *AccountUpdateOne) SetPriority(v int) *AccountUpdateOne {
_u.mutation.ResetPriority()
@@ -1605,6 +1668,15 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.AddedConcurrency(); ok {
_spec.AddField(account.FieldConcurrency, field.TypeInt, value)
}
if value, ok := _u.mutation.LoadFactor(); ok {
_spec.SetField(account.FieldLoadFactor, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedLoadFactor(); ok {
_spec.AddField(account.FieldLoadFactor, field.TypeInt, value)
}
if _u.mutation.LoadFactorCleared() {
_spec.ClearField(account.FieldLoadFactor, field.TypeInt)
}
if value, ok := _u.mutation.Priority(); ok {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
}

View File

@@ -62,22 +62,24 @@ type Group struct {
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// 是否仅允许 Claude Code 客户端
// allow Claude Code client only
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// Claude Code 请求降级使用的分组 ID
// fallback group for non-Claude-Code requests
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// 无效请求兜底使用的分组 ID
// fallback group for invalid request
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// 模型路由配置:模型模式 -> 优先账号ID列表
// model routing config: pattern -> account ids
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
// 是否启用模型路由配置
// whether model routing is enabled
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
// whether MCP XML prompt injection is enabled
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
// 支持的模型系列:claude, gemini_text, gemini_image
// supported model scopes: claude, gemini_text, gemini_image
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// 分组显示排序,数值越小越靠前
// group display order, lower comes first
SortOrder int `json:"sort_order,omitempty"`
// simulate claude usage as claude-max style (1h cache write)
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -186,7 +188,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldSimulateClaudeMaxEnabled:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
values[i] = new(sql.NullFloat64)
@@ -415,6 +417,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.SortOrder = int(value.Int64)
}
case group.FieldSimulateClaudeMaxEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i])
} else if value.Valid {
_m.SimulateClaudeMaxEnabled = value.Bool
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -608,6 +616,9 @@ func (_m *Group) String() string {
builder.WriteString(", ")
builder.WriteString("sort_order=")
builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
builder.WriteString(", ")
builder.WriteString("simulate_claude_max_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -75,6 +75,8 @@ const (
FieldSupportedModelScopes = "supported_model_scopes"
// FieldSortOrder holds the string denoting the sort_order field in the database.
FieldSortOrder = "sort_order"
// FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database.
FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -180,6 +182,7 @@ var Columns = []string{
FieldMcpXMLInject,
FieldSupportedModelScopes,
FieldSortOrder,
FieldSimulateClaudeMaxEnabled,
}
var (
@@ -247,6 +250,8 @@ var (
DefaultSupportedModelScopes []string
// DefaultSortOrder holds the default value on creation for the "sort_order" field.
DefaultSortOrder int
// DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field.
DefaultSimulateClaudeMaxEnabled bool
)
// OrderOption defines the ordering options for the Group queries.
@@ -397,6 +402,11 @@ func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
}
// BySimulateClaudeMaxEnabled returns an OrderOption that orders results by the
// simulate_claude_max_enabled field. Any opts are forwarded unchanged to
// sql.OrderByField.
func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption {
	term := sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...)
	return term.ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -195,6 +195,11 @@ func SortOrder(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
}
// SimulateClaudeMaxEnabled builds an equality predicate on the
// "simulate_claude_max_enabled" field. It behaves exactly like
// SimulateClaudeMaxEnabledEQ.
func SimulateClaudeMaxEnabled(v bool) predicate.Group {
	equals := sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v)
	return predicate.Group(equals)
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1470,6 +1475,16 @@ func SortOrderLTE(v int) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSortOrder, v))
}
// SimulateClaudeMaxEnabledEQ builds the EQ predicate on the
// "simulate_claude_max_enabled" field for value v.
func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group {
	equals := sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v)
	return predicate.Group(equals)
}
// SimulateClaudeMaxEnabledNEQ builds the NEQ predicate on the
// "simulate_claude_max_enabled" field for value v.
func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group {
	differs := sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v)
	return predicate.Group(differs)
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -424,6 +424,20 @@ func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate {
return _c
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
// The value is recorded on the builder's mutation and the builder itself is
// returned so calls can be chained.
func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate {
_c.mutation.SetSimulateClaudeMaxEnabled(v)
return _c
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled"
// field when v is non-nil. A nil v leaves the builder untouched; either way
// the builder is returned.
func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate {
	if v == nil {
		return _c
	}
	return _c.SetSimulateClaudeMaxEnabled(*v)
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -613,6 +627,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultSortOrder
_c.mutation.SetSortOrder(v)
}
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
v := group.DefaultSimulateClaudeMaxEnabled
_c.mutation.SetSimulateClaudeMaxEnabled(v)
}
return nil
}
@@ -683,6 +701,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.SortOrder(); !ok {
return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)}
}
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)}
}
return nil
}
@@ -830,6 +851,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
_node.SortOrder = value
}
if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
_node.SimulateClaudeMaxEnabled = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1520,6 +1545,18 @@ func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert {
return u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to v
// and returns the receiver for chaining.
func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert {
	const column = group.FieldSimulateClaudeMaxEnabled
	u.Set(column, v)
	return u
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field
// to the value that was provided on create (via SetExcluded) and returns the
// receiver for chaining.
func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert {
	const column = group.FieldSimulateClaudeMaxEnabled
	u.SetExcluded(column)
	return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2188,6 +2225,20 @@ func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne {
})
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field by
// registering an update that calls SetSimulateClaudeMaxEnabled(v) on the
// underlying GroupUpsert.
func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne {
	apply := func(upsert *GroupUpsert) {
		upsert.SetSimulateClaudeMaxEnabled(v)
	}
	return u.Update(apply)
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field
// to the value that was provided on create, by registering an update that
// calls UpdateSimulateClaudeMaxEnabled on the underlying GroupUpsert.
func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne {
	apply := func(upsert *GroupUpsert) {
		upsert.UpdateSimulateClaudeMaxEnabled()
	}
	return u.Update(apply)
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -3022,6 +3073,20 @@ func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk {
})
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field by
// registering an update that calls SetSimulateClaudeMaxEnabled(v) on the
// underlying GroupUpsert.
func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk {
	apply := func(upsert *GroupUpsert) {
		upsert.SetSimulateClaudeMaxEnabled(v)
	}
	return u.Update(apply)
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field
// to the value that was provided on create, by registering an update that
// calls UpdateSimulateClaudeMaxEnabled on the underlying GroupUpsert.
func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk {
	apply := func(upsert *GroupUpsert) {
		upsert.UpdateSimulateClaudeMaxEnabled()
	}
	return u.Update(apply)
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -625,6 +625,20 @@ func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate {
return _u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
// The value is recorded on the builder's mutation and the builder itself is
// returned so calls can be chained.
func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate {
_u.mutation.SetSimulateClaudeMaxEnabled(v)
return _u
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled"
// field when v is non-nil. A nil v leaves the builder untouched; either way
// the builder is returned.
func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate {
	if v == nil {
		return _u
	}
	return _u.SetSimulateClaudeMaxEnabled(*v)
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1110,6 +1124,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedSortOrder(); ok {
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -2014,6 +2031,20 @@ func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne {
return _u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
// The value is recorded on the builder's mutation and the builder itself is
// returned so calls can be chained.
func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne {
_u.mutation.SetSimulateClaudeMaxEnabled(v)
return _u
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled"
// field when v is non-nil. A nil v leaves the builder untouched; either way
// the builder is returned.
func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne {
	if v == nil {
		return _u
	}
	return _u.SetSimulateClaudeMaxEnabled(*v)
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2529,6 +2560,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.AddedSortOrder(); ok {
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -106,6 +106,7 @@ var (
{Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "concurrency", Type: field.TypeInt, Default: 3},
{Name: "load_factor", Type: field.TypeInt, Nullable: true},
{Name: "priority", Type: field.TypeInt, Default: 50},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
@@ -132,7 +133,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
Columns: []*schema.Column{AccountsColumns[27]},
Columns: []*schema.Column{AccountsColumns[28]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -151,52 +152,52 @@ var (
{
Name: "account_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[13]},
Columns: []*schema.Column{AccountsColumns[14]},
},
{
Name: "account_proxy_id",
Unique: false,
Columns: []*schema.Column{AccountsColumns[27]},
Columns: []*schema.Column{AccountsColumns[28]},
},
{
Name: "account_priority",
Unique: false,
Columns: []*schema.Column{AccountsColumns[11]},
Columns: []*schema.Column{AccountsColumns[12]},
},
{
Name: "account_last_used_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[15]},
Columns: []*schema.Column{AccountsColumns[16]},
},
{
Name: "account_schedulable",
Unique: false,
Columns: []*schema.Column{AccountsColumns[18]},
Columns: []*schema.Column{AccountsColumns[19]},
},
{
Name: "account_rate_limited_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[19]},
Columns: []*schema.Column{AccountsColumns[20]},
},
{
Name: "account_rate_limit_reset_at",
Unique: false,
Columns: []*schema.Column{AccountsColumns[20]},
Columns: []*schema.Column{AccountsColumns[21]},
},
{
Name: "account_overload_until",
Unique: false,
Columns: []*schema.Column{AccountsColumns[21]},
Columns: []*schema.Column{AccountsColumns[22]},
},
{
Name: "account_platform_priority",
Unique: false,
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]},
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[12]},
},
{
Name: "account_priority_status",
Unique: false,
Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]},
Columns: []*schema.Column{AccountsColumns[12], AccountsColumns[14]},
},
{
Name: "account_deleted_at",
@@ -406,6 +407,7 @@ var (
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "sort_order", Type: field.TypeInt, Default: 0},
{Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{

View File

@@ -2260,6 +2260,8 @@ type AccountMutation struct {
extra *map[string]interface{}
concurrency *int
addconcurrency *int
load_factor *int
addload_factor *int
priority *int
addpriority *int
rate_multiplier *float64
@@ -2845,6 +2847,76 @@ func (m *AccountMutation) ResetConcurrency() {
m.addconcurrency = nil
}
// SetLoadFactor sets the "load_factor" field to i and discards any pending
// increment previously recorded through AddLoadFactor.
func (m *AccountMutation) SetLoadFactor(i int) {
	m.load_factor, m.addload_factor = &i, nil
}
// LoadFactor returns the value of the "load_factor" field in the mutation.
// exists is false (and r is zero) when no value has been set.
func (m *AccountMutation) LoadFactor() (r int, exists bool) {
	if m.load_factor == nil {
		return 0, false
	}
	return *m.load_factor, true
}
// OldLoadFactor returns the old "load_factor" field's value of the Account entity.
// If the Account object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, the mutation
// has no ID or old-value loader, or the database query fails.
func (m *AccountMutation) OldLoadFactor(ctx context.Context) (v *int, err error) {
	// Validate the preconditions up front; v is still nil on these paths.
	switch {
	case !m.op.Is(OpUpdateOne):
		return nil, errors.New("OldLoadFactor is only allowed on UpdateOne operations")
	case m.id == nil || m.oldValue == nil:
		return nil, errors.New("OldLoadFactor requires an ID field in the mutation")
	}
	previous, err := m.oldValue(ctx)
	if err != nil {
		return nil, fmt.Errorf("querying old value for OldLoadFactor: %w", err)
	}
	return previous.LoadFactor, nil
}
// AddLoadFactor adds i to the "load_factor" field.
// Repeated calls accumulate into a single pending increment.
func (m *AccountMutation) AddLoadFactor(i int) {
	if m.addload_factor == nil {
		m.addload_factor = &i
		return
	}
	*m.addload_factor += i
}
// AddedLoadFactor returns the value that was added to the "load_factor" field
// in this mutation. exists is false (and r is zero) when nothing was added.
func (m *AccountMutation) AddedLoadFactor() (r int, exists bool) {
	if m.addload_factor == nil {
		return 0, false
	}
	return *m.addload_factor, true
}
// ClearLoadFactor clears the value of the "load_factor" field.
// Any pending set or add value is dropped and the field is recorded in
// clearedFields.
func (m *AccountMutation) ClearLoadFactor() {
	m.load_factor, m.addload_factor = nil, nil
	m.clearedFields[account.FieldLoadFactor] = struct{}{}
}
// LoadFactorCleared reports whether the "load_factor" field was cleared in
// this mutation, i.e. whether it is present in clearedFields.
func (m *AccountMutation) LoadFactorCleared() bool {
	_, cleared := m.clearedFields[account.FieldLoadFactor]
	return cleared
}
// ResetLoadFactor resets all changes to the "load_factor" field: the pending
// set value, the pending increment, and the cleared marker are all removed.
func (m *AccountMutation) ResetLoadFactor() {
	m.load_factor, m.addload_factor = nil, nil
	delete(m.clearedFields, account.FieldLoadFactor)
}
// SetPriority sets the "priority" field.
func (m *AccountMutation) SetPriority(i int) {
m.priority = &i
@@ -3773,7 +3845,7 @@ func (m *AccountMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AccountMutation) Fields() []string {
fields := make([]string, 0, 27)
fields := make([]string, 0, 28)
if m.created_at != nil {
fields = append(fields, account.FieldCreatedAt)
}
@@ -3807,6 +3879,9 @@ func (m *AccountMutation) Fields() []string {
if m.concurrency != nil {
fields = append(fields, account.FieldConcurrency)
}
if m.load_factor != nil {
fields = append(fields, account.FieldLoadFactor)
}
if m.priority != nil {
fields = append(fields, account.FieldPriority)
}
@@ -3885,6 +3960,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
return m.ProxyID()
case account.FieldConcurrency:
return m.Concurrency()
case account.FieldLoadFactor:
return m.LoadFactor()
case account.FieldPriority:
return m.Priority()
case account.FieldRateMultiplier:
@@ -3948,6 +4025,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldProxyID(ctx)
case account.FieldConcurrency:
return m.OldConcurrency(ctx)
case account.FieldLoadFactor:
return m.OldLoadFactor(ctx)
case account.FieldPriority:
return m.OldPriority(ctx)
case account.FieldRateMultiplier:
@@ -4066,6 +4145,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
}
m.SetConcurrency(v)
return nil
case account.FieldLoadFactor:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetLoadFactor(v)
return nil
case account.FieldPriority:
v, ok := value.(int)
if !ok {
@@ -4189,6 +4275,9 @@ func (m *AccountMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, account.FieldConcurrency)
}
if m.addload_factor != nil {
fields = append(fields, account.FieldLoadFactor)
}
if m.addpriority != nil {
fields = append(fields, account.FieldPriority)
}
@@ -4205,6 +4294,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
switch name {
case account.FieldConcurrency:
return m.AddedConcurrency()
case account.FieldLoadFactor:
return m.AddedLoadFactor()
case account.FieldPriority:
return m.AddedPriority()
case account.FieldRateMultiplier:
@@ -4225,6 +4316,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
case account.FieldLoadFactor:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddLoadFactor(v)
return nil
case account.FieldPriority:
v, ok := value.(int)
if !ok {
@@ -4256,6 +4354,9 @@ func (m *AccountMutation) ClearedFields() []string {
if m.FieldCleared(account.FieldProxyID) {
fields = append(fields, account.FieldProxyID)
}
if m.FieldCleared(account.FieldLoadFactor) {
fields = append(fields, account.FieldLoadFactor)
}
if m.FieldCleared(account.FieldErrorMessage) {
fields = append(fields, account.FieldErrorMessage)
}
@@ -4312,6 +4413,9 @@ func (m *AccountMutation) ClearField(name string) error {
case account.FieldProxyID:
m.ClearProxyID()
return nil
case account.FieldLoadFactor:
m.ClearLoadFactor()
return nil
case account.FieldErrorMessage:
m.ClearErrorMessage()
return nil
@@ -4386,6 +4490,9 @@ func (m *AccountMutation) ResetField(name string) error {
case account.FieldConcurrency:
m.ResetConcurrency()
return nil
case account.FieldLoadFactor:
m.ResetLoadFactor()
return nil
case account.FieldPriority:
m.ResetPriority()
return nil
@@ -8089,6 +8196,7 @@ type GroupMutation struct {
appendsupported_model_scopes []string
sort_order *int
addsort_order *int
simulate_claude_max_enabled *bool
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -9833,6 +9941,42 @@ func (m *GroupMutation) ResetSortOrder() {
m.addsort_order = nil
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
// The mutation keeps its own copy of b, so later changes by the caller do
// not affect the staged value.
func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) {
	staged := b
	m.simulate_claude_max_enabled = &staged
}
// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation.
// The second result is false when the field has not been set on this
// mutation.
func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) {
	if staged := m.simulate_claude_max_enabled; staged != nil {
		return *staged, true
	}
	return false, false
}
// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) {
	// Reject mutations that are not UpdateOne or that lack a target ID.
	switch {
	case !m.op.Is(OpUpdateOne):
		return false, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations")
	case m.id == nil || m.oldValue == nil:
		return false, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation")
	}
	// Resolve the pre-mutation entity through the mutation's oldValue loader.
	previous, loadErr := m.oldValue(ctx)
	if loadErr != nil {
		return false, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", loadErr)
	}
	return previous.SimulateClaudeMaxEnabled, nil
}
// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field.
// It drops the staged value, so SimulateClaudeMaxEnabled reports exists=false
// afterwards.
func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() {
m.simulate_claude_max_enabled = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -10191,7 +10335,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 30)
fields := make([]string, 0, 31)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -10282,6 +10426,9 @@ func (m *GroupMutation) Fields() []string {
if m.sort_order != nil {
fields = append(fields, group.FieldSortOrder)
}
if m.simulate_claude_max_enabled != nil {
fields = append(fields, group.FieldSimulateClaudeMaxEnabled)
}
return fields
}
@@ -10350,6 +10497,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.SupportedModelScopes()
case group.FieldSortOrder:
return m.SortOrder()
case group.FieldSimulateClaudeMaxEnabled:
return m.SimulateClaudeMaxEnabled()
}
return nil, false
}
@@ -10419,6 +10568,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldSupportedModelScopes(ctx)
case group.FieldSortOrder:
return m.OldSortOrder(ctx)
case group.FieldSimulateClaudeMaxEnabled:
return m.OldSimulateClaudeMaxEnabled(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -10638,6 +10789,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetSortOrder(v)
return nil
case group.FieldSimulateClaudeMaxEnabled:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSimulateClaudeMaxEnabled(v)
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -11065,6 +11223,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldSortOrder:
m.ResetSortOrder()
return nil
case group.FieldSimulateClaudeMaxEnabled:
m.ResetSimulateClaudeMaxEnabled()
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}

View File

@@ -212,29 +212,29 @@ func init() {
// account.DefaultConcurrency holds the default value on creation for the concurrency field.
account.DefaultConcurrency = accountDescConcurrency.Default.(int)
// accountDescPriority is the schema descriptor for priority field.
accountDescPriority := accountFields[8].Descriptor()
accountDescPriority := accountFields[9].Descriptor()
// account.DefaultPriority holds the default value on creation for the priority field.
account.DefaultPriority = accountDescPriority.Default.(int)
// accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
accountDescRateMultiplier := accountFields[9].Descriptor()
accountDescRateMultiplier := accountFields[10].Descriptor()
// account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
// accountDescStatus is the schema descriptor for status field.
accountDescStatus := accountFields[10].Descriptor()
accountDescStatus := accountFields[11].Descriptor()
// account.DefaultStatus holds the default value on creation for the status field.
account.DefaultStatus = accountDescStatus.Default.(string)
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
accountDescAutoPauseOnExpired := accountFields[15].Descriptor()
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
// accountDescSchedulable is the schema descriptor for schedulable field.
accountDescSchedulable := accountFields[15].Descriptor()
accountDescSchedulable := accountFields[16].Descriptor()
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
accountDescSessionWindowStatus := accountFields[23].Descriptor()
accountDescSessionWindowStatus := accountFields[24].Descriptor()
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
accountgroupFields := schema.AccountGroup{}.Fields()
@@ -447,6 +447,10 @@ func init() {
groupDescSortOrder := groupFields[26].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
// groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field.
groupDescSimulateClaudeMaxEnabled := groupFields[27].Descriptor()
// group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field.
group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0

View File

@@ -97,6 +97,8 @@ func (Account) Fields() []ent.Field {
field.Int("concurrency").
Default(3),
field.Int("load_factor").Optional().Nillable(),
// priority: 账户优先级,数值越小优先级越高
// 调度器会优先使用高优先级的账户
field.Int("priority").

View File

@@ -33,8 +33,6 @@ func (Group) Mixin() []ent.Mixin {
func (Group) Fields() []ent.Field {
return []ent.Field{
// 唯一约束通过部分索引实现WHERE deleted_at IS NULL支持软删除后重用
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
field.String("name").
MaxLen(100).
NotEmpty(),
@@ -51,7 +49,6 @@ func (Group) Fields() []ent.Field {
MaxLen(20).
Default(domain.StatusActive),
// Subscription-related fields (added by migration 003)
field.String("platform").
MaxLen(50).
Default(domain.PlatformAnthropic),
@@ -73,7 +70,6 @@ func (Group) Fields() []ent.Field {
field.Int("default_validity_days").
Default(30),
// 图片生成计费配置antigravity 和 gemini 平台使用)
field.Float("image_price_1k").
Optional().
Nillable().
@@ -87,7 +83,6 @@ func (Group) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Sora 按次计费配置(阶段 1
field.Float("sora_image_price_360").
Optional().
Nillable().
@@ -109,45 +104,41 @@ func (Group) Fields() []ent.Field {
field.Int64("sora_storage_quota_bytes").
Default(0),
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).
Comment("是否仅允许 Claude Code 客户端"),
Comment("allow Claude Code client only"),
field.Int64("fallback_group_id").
Optional().
Nillable().
Comment("Claude Code 请求降级使用的分组 ID"),
Comment("fallback group for non-Claude-Code requests"),
field.Int64("fallback_group_id_on_invalid_request").
Optional().
Nillable().
Comment("无效请求兜底使用的分组 ID"),
Comment("fallback group for invalid request"),
// 模型路由配置 (added by migration 040)
field.JSON("model_routing", map[string][]int64{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
// 模型路由开关 (added by migration 041)
Comment("model routing config: pattern -> account ids"),
field.Bool("model_routing_enabled").
Default(false).
Comment("是否启用模型路由配置"),
Comment("whether model routing is enabled"),
// MCP XML 协议注入开关 (added by migration 042)
field.Bool("mcp_xml_inject").
Default(true).
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
Comment("whether MCP XML prompt injection is enabled"),
// 支持的模型系列 (added by migration 046)
field.JSON("supported_model_scopes", []string{}).
Default([]string{"claude", "gemini_text", "gemini_image"}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
Comment("supported model scopes: claude, gemini_text, gemini_image"),
// 分组排序 (added by migration 052)
field.Int("sort_order").
Default(0).
Comment("分组显示排序,数值越小越靠前"),
Comment("group display order, lower comes first"),
field.Bool("simulate_claude_max_enabled").
Default(false).
Comment("simulate claude usage as claude-max style (1h cache write)"),
}
}
@@ -163,14 +154,11 @@ func (Group) Edges() []ent.Edge {
edge.From("allowed_users", User.Type).
Ref("allowed_groups").
Through("user_allowed_groups", UserAllowedGroup.Type),
// 注意fallback_group_id 直接作为字段使用,不定义 edge
// 这样允许多个分组指向同一个降级分组M2O 关系)
}
}
func (Group) Indexes() []ent.Index {
return []ent.Index{
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
index.Fields("status"),
index.Fields("platform"),
index.Fields("subscription_type"),

View File

@@ -89,6 +89,7 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
@@ -140,6 +141,8 @@ require (
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.6.0 // indirect

View File

@@ -124,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
@@ -171,8 +173,6 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -182,7 +182,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -203,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -285,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -398,8 +403,6 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
@@ -455,8 +458,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=

View File

@@ -102,6 +102,7 @@ type CreateAccountRequest struct {
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
@@ -120,6 +121,7 @@ type UpdateAccountRequest struct {
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
@@ -135,6 +137,7 @@ type BulkUpdateAccountsRequest struct {
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier"`
LoadFactor *int `json:"load_factor"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
@@ -506,6 +509,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
LoadFactor: req.LoadFactor,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
@@ -575,6 +579,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
RateMultiplier: req.RateMultiplier,
LoadFactor: req.LoadFactor,
Status: req.Status,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
@@ -1101,6 +1106,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.Concurrency != nil ||
req.Priority != nil ||
req.RateMultiplier != nil ||
req.LoadFactor != nil ||
req.Status != "" ||
req.Schedulable != nil ||
req.GroupIDs != nil ||
@@ -1119,6 +1125,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
LoadFactor: req.LoadFactor,
Status: req.Status,
Schedulable: req.Schedulable,
GroupIDs: req.GroupIDs,
@@ -1132,6 +1139,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
@@ -1328,6 +1341,29 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// ResetQuota resets the quota usage of a single account and responds with the
// refreshed account.
// POST /api/v1/admin/accounts/:id/reset-quota
//
// Responses:
//   - bad request when the :id path parameter is not a base-10 int64
//   - internal error (including the underlying error text) when the reset fails
//   - response.ErrorFrom mapping when the account cannot be re-read afterwards
//   - success with the account response (including runtime data) otherwise
func (h *AccountHandler) ResetQuota(c *gin.Context) {
	id, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		response.BadRequest(c, "Invalid account ID")
		return
	}

	ctx := c.Request.Context()

	if resetErr := h.adminService.ResetAccountQuota(ctx, id); resetErr != nil {
		response.InternalError(c, "Failed to reset account quota: "+resetErr.Error())
		return
	}

	// Re-read the account so the response reflects the state after the reset.
	refreshed, loadErr := h.adminService.GetAccount(ctx, id)
	if loadErr != nil {
		response.ErrorFrom(c, loadErr)
		return
	}

	response.Success(c, h.buildAccountResponseWithRuntime(ctx, refreshed))
}
// GetTempUnschedulable handles getting temporary unschedulable status
// GET /api/v1/admin/accounts/:id/temp-unschedulable
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {

View File

@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
require.Contains(t, resp["message"], "claude-max")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
require.Contains(t, resp["message"], "claude-max")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)

View File

@@ -425,5 +425,9 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return nil, service.ErrAPIKeyNotFound
}
// ResetAccountQuota is a no-op stub: it ignores its arguments and always
// reports success.
func (s *stubAdminService) ResetAccountQuota(_ context.Context, _ int64) error {
	return nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -46,9 +46,10 @@ type CreateGroupRequest struct {
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
@@ -81,9 +82,10 @@ type UpdateGroupRequest struct {
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
@@ -201,6 +203,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
@@ -252,6 +255,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,

View File

@@ -122,13 +122,14 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
@@ -183,6 +184,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
LoadFactor: a.LoadFactor,
Priority: a.Priority,
RateMultiplier: a.BillingRateMultiplier(),
Status: a.Status,
@@ -248,6 +250,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
}
}
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
if a.Type == service.AccountTypeAPIKey {
if limit := a.GetQuotaLimit(); limit > 0 {
out.QuotaLimit = &limit
}
used := a.GetQuotaUsed()
if out.QuotaLimit != nil {
out.QuotaUsed = &used
}
}
return out
}

View File

@@ -111,6 +111,8 @@ type AdminGroup struct {
// MCP XML 协议注入(仅 antigravity 平台使用)
MCPXMLInject bool `json:"mcp_xml_inject"`
// Claude usage 模拟开关(仅管理员可见)
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
@@ -131,6 +133,7 @@ type Account struct {
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
LoadFactor *int `json:"load_factor,omitempty"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
@@ -185,6 +188,10 @@ type Account struct {
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
// API Key 账号配额限制
QuotaLimit *float64 `json:"quota_limit,omitempty"`
QuotaUsed *float64 `json:"quota_used,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`

View File

@@ -439,6 +439,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
APIKey: apiKey,
User: apiKey.User,
Account: account,
@@ -630,6 +631,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// ===== 用户消息串行队列 END =====
// 转发请求 - 根据账号平台分流
c.Set("parsed_request", parsedReq)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
@@ -734,6 +736,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,

View File

@@ -2132,6 +2132,14 @@ func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service
return 0, nil
}
// IncrementQuotaUsed is a no-op stub: it ignores its arguments and always
// returns nil.
func (r *stubAccountRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) error {
	return nil
}
// ResetQuotaUsed is a no-op stub: it ignores its arguments and always
// returns nil.
func (r *stubAccountRepoForHandler) ResetQuotaUsed(_ context.Context, _ int64) error {
	return nil
}
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)

View File

@@ -216,6 +216,14 @@ func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
return 0, nil
}
// IncrementQuotaUsed is a no-op stub: the account ID and amount are ignored
// and nil is always returned.
func (r *stubAccountRepo) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) error {
	return nil
}
// ResetQuotaUsed is a no-op stub: the account ID is ignored and nil is
// always returned.
func (r *stubAccountRepo) ResetQuotaUsed(_ context.Context, _ int64) error {
	return nil
}
func (r *stubAccountRepo) listSchedulable() []service.Account {
var result []service.Account
for _, acc := range r.accounts {

View File

@@ -18,6 +18,9 @@ const (
BlockTypeFunction
)
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
// The hook receives the usage as a mutable map and edits it in place; it
// returns nothing, so any change must be made on the map it is given.
type UsageMapHook func(usageMap map[string]any)
// StreamingProcessor 流式响应处理器
type StreamingProcessor struct {
blockType BlockType
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
usageMapHook UsageMapHook
// 累计 usage
inputTokens int
@@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
}
}
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
// The value is stored on the processor as given, replacing any previously set hook.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
// usageToMap converts a ClaudeUsage value into a generic map keyed by the usage
// field names used in emitted events.
//
// "input_tokens" and "output_tokens" are always present. The two cache counters
// are added only when they are greater than zero.
func usageToMap(u ClaudeUsage) map[string]any {
	usage := make(map[string]any, 4)
	usage["input_tokens"] = u.InputTokens
	usage["output_tokens"] = u.OutputTokens

	if created := u.CacheCreationInputTokens; created > 0 {
		usage["cache_creation_input_tokens"] = created
	}
	if read := u.CacheReadInputTokens; read > 0 {
		usage["cache_read_input_tokens"] = read
	}
	return usage
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
@@ -158,6 +181,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID()
}
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
message := map[string]any{
"id": responseID,
"type": "message",
@@ -166,7 +196,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel,
"stop_reason": nil,
"stop_sequence": nil,
"usage": usage,
"usage": usageValue,
}
event := map[string]any{
@@ -477,13 +507,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
CacheReadInputTokens: p.cacheReadTokens,
}
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
deltaEvent := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usage,
"usage": usageValue,
}
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))

View File

@@ -84,6 +84,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.LoadFactor != nil {
builder.SetLoadFactor(*account.LoadFactor)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
@@ -318,6 +321,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.LoadFactor != nil {
builder.SetLoadFactor(*account.LoadFactor)
} else {
builder.ClearLoadFactor()
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
@@ -1223,6 +1231,15 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.RateMultiplier)
idx++
}
if updates.LoadFactor != nil {
if *updates.LoadFactor <= 0 {
setClauses = append(setClauses, "load_factor = NULL")
} else {
setClauses = append(setClauses, "load_factor = $"+itoa(idx))
args = append(args, *updates.LoadFactor)
idx++
}
}
if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
@@ -1545,6 +1562,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
LoadFactor: m.LoadFactor,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
@@ -1657,3 +1675,60 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
return r.accountsToService(ctx, accounts)
}
// IncrementQuotaUsed atomically adds amount to the account's extra.quota_used field.
//
// The increment is performed by a single UPDATE statement whose RETURNING clause
// yields the post-update quota_used together with quota_limit (each treated as 0
// when absent). If this particular increment moved usage from below the limit to
// at or above it, an account-changed scheduler outbox event is enqueued so the
// account is promptly removed from the scheduling candidates.
//
// When no row matches (unknown or soft-deleted id) nothing is updated and nil is
// returned. A failure to enqueue the outbox event is logged, not returned.
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
	rows, err := r.sql.QueryContext(ctx,
		`UPDATE accounts SET extra = jsonb_set(
COALESCE(extra, '{}'::jsonb),
'{quota_used}',
to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + $1)
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
		amount, id)
	if err != nil {
		return err
	}
	// Safety net for the early returns below; closing twice is harmless.
	defer func() { _ = rows.Close() }()

	var newUsed, limit float64
	if rows.Next() {
		if err := rows.Scan(&newUsed, &limit); err != nil {
			return err
		}
	}
	if err := rows.Err(); err != nil {
		return err
	}

	// Release the result set before issuing the follow-up statement. Keeping it
	// open would either pin a second pooled connection or, when r.sql is bound to
	// a single connection/transaction, run the enqueue on a connection that is
	// still busy streaming this result.
	_ = rows.Close()

	// Refresh the scheduler snapshot only when this increment is the one that
	// crossed the limit, so the account leaves the candidate set promptly.
	if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
		if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
			logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", id, err)
		}
	}
	return nil
}
// ResetQuotaUsed sets the account's extra.quota_used field back to 0.
//
// After the UPDATE succeeds, an account-changed scheduler outbox event is
// enqueued so the account can take part in scheduling again. An enqueue failure
// is only logged; the method still returns nil in that case.
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
	const resetQuotaSQL = `UPDATE accounts SET extra = jsonb_set(
COALESCE(extra, '{}'::jsonb),
'{quota_used}',
'0'::jsonb
), updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`

	if _, execErr := r.sql.ExecContext(ctx, resetQuotaSQL, id); execErr != nil {
		return execErr
	}

	// Trigger a scheduler snapshot refresh now that the quota has been reset.
	enqueueErr := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil)
	if enqueueErr != nil {
		logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue quota reset failed: account=%d err=%v", id, enqueueErr)
	}
	return nil
}

View File

@@ -164,6 +164,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
group.FieldMcpXMLInject,
group.FieldSimulateClaudeMaxEnabled,
group.FieldSupportedModelScopes,
)
}).
@@ -617,6 +618,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.McpXMLInject,
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
SupportedModelScopes: g.SupportedModelScopes,
SortOrder: g.SortOrder,
CreatedAt: g.CreatedAt,

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -95,7 +96,8 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body))
return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg)
}
var usageResp service.ClaudeUsageResponse

View File

@@ -59,7 +59,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -125,7 +126,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {

View File

@@ -1870,7 +1870,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
query := `
SELECT
COALESCE(ul.group_id, 0) as group_id,
COALESCE(g.name, '') as group_name,
COALESCE(g.name, '(无分组)') as group_name,
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,

View File

@@ -1096,6 +1096,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map
return errors.New("not implemented")
}
// IncrementQuotaUsed is an unimplemented stub: it always returns a "not implemented" error.
func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return errors.New("not implemented")
}
// ResetQuotaUsed is an unimplemented stub: it always returns a "not implemented" error.
func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
s.bulkUpdateIDs = append([]int64{}, ids...)
return int64(len(ids)), nil

View File

@@ -252,6 +252,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)

View File

@@ -28,6 +28,7 @@ type Account struct {
// RateMultiplier 账号计费倍率(>=0允许 0 表示该账号计费为 0
// 使用指针用于兼容旧版本调度缓存Redis中缺字段的情况nil 表示按 1.0 处理。
RateMultiplier *float64
LoadFactor *int // 调度负载因子nil 表示使用 Concurrency
Status string
ErrorMessage string
LastUsedAt *time.Time
@@ -88,6 +89,19 @@ func (a *Account) BillingRateMultiplier() float64 {
return *a.RateMultiplier
}
// EffectiveLoadFactor returns the load factor to use for this account.
//
// Precedence: a positive LoadFactor wins; otherwise a positive Concurrency is
// used; otherwise 1. A nil receiver also yields 1.
func (a *Account) EffectiveLoadFactor() int {
	switch {
	case a == nil:
		return 1
	case a.LoadFactor != nil && *a.LoadFactor > 0:
		return *a.LoadFactor
	case a.Concurrency > 0:
		return a.Concurrency
	default:
		return 1
	}
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
@@ -1117,6 +1131,38 @@ func (a *Account) GetCacheTTLOverrideTarget() string {
return "5m"
}
// GetQuotaLimit returns the quota limit (in USD) configured for an API Key account.
// It reads Extra["quota_limit"]; a result of 0 means the limit is not enabled.
func (a *Account) GetQuotaLimit() float64 {
	// Indexing a nil map is safe and reports the key as missing.
	raw, found := a.Extra["quota_limit"]
	if !found {
		return 0
	}
	return parseExtraFloat64(raw)
}
// GetQuotaUsed returns the quota already consumed (in USD) by an API Key account.
// It reads Extra["quota_used"] and yields 0 when the key is absent.
func (a *Account) GetQuotaUsed() float64 {
	// Indexing a nil map is safe and reports the key as missing.
	raw, found := a.Extra["quota_used"]
	if !found {
		return 0
	}
	return parseExtraFloat64(raw)
}
// IsQuotaExceeded reports whether an API Key account has used up its quota.
// A non-positive limit means no quota is enforced, so the result is false.
func (a *Account) IsQuotaExceeded() bool {
	if quotaLimit := a.GetQuotaLimit(); quotaLimit > 0 {
		return a.GetQuotaUsed() >= quotaLimit
	}
	return false
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {

View File

@@ -0,0 +1,42 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestEffectiveLoadFactor_NilAccount checks that a nil *Account yields 1.
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
	var acct *Account
	got := acct.EffectiveLoadFactor()
	require.Equal(t, 1, got)
}
// TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency checks that, without
// a LoadFactor, a positive Concurrency is returned.
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
	acct := &Account{Concurrency: 5}
	got := acct.EffectiveLoadFactor()
	require.Equal(t, 5, got)
}
// TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency checks the final fallback
// to 1 when neither LoadFactor nor Concurrency is usable.
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
	acct := &Account{Concurrency: 0}
	got := acct.EffectiveLoadFactor()
	require.Equal(t, 1, got)
}
// TestEffectiveLoadFactor_PositiveLoadFactor checks that a positive LoadFactor
// takes precedence over Concurrency.
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
	acct := &Account{
		Concurrency: 5,
		LoadFactor:  intPtr(20),
	}
	got := acct.EffectiveLoadFactor()
	require.Equal(t, 20, got)
}
// TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency checks that a
// zero LoadFactor is ignored in favour of Concurrency.
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
	acct := &Account{
		Concurrency: 5,
		LoadFactor:  intPtr(0),
	}
	got := acct.EffectiveLoadFactor()
	require.Equal(t, 5, got)
}
// TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency checks that a
// negative LoadFactor is ignored in favour of Concurrency.
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
	acct := &Account{
		Concurrency: 3,
		LoadFactor:  intPtr(-1),
	}
	got := acct.EffectiveLoadFactor()
	require.Equal(t, 3, got)
}
// TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency checks that 1 is
// returned when both LoadFactor and Concurrency are zero.
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
	acct := &Account{
		Concurrency: 0,
		LoadFactor:  intPtr(0),
	}
	got := acct.EffectiveLoadFactor()
	require.Equal(t, 1, got)
}

View File

@@ -68,6 +68,10 @@ type AccountRepository interface {
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
// ResetQuotaUsed 重置 API Key 账号的配额用量为 0
ResetQuotaUsed(ctx context.Context, id int64) error
}
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
@@ -78,6 +82,7 @@ type AccountBulkUpdate struct {
Concurrency *int
Priority *int
RateMultiplier *float64
LoadFactor *int
Status *string
Schedulable *bool
Credentials map[string]any

View File

@@ -199,6 +199,14 @@ func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates A
panic("unexpected BulkUpdate call")
}
// IncrementQuotaUsed is a no-op stub: the arguments are unused and nil is always returned.
func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
// ResetQuotaUsed is a no-op stub: the arguments are unused and nil is always returned.
func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
// 预期行为:
// - ExistsByID 返回 false账号不存在

View File

@@ -180,7 +180,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
if account.Platform == PlatformAntigravity {
return s.testAntigravityAccountConnection(c, account, modelID)
return s.routeAntigravityTest(c, account, modelID)
}
if account.Platform == PlatformSora {
@@ -1177,6 +1177,18 @@ func truncateSoraErrorBody(body []byte, max int) string {
return soraerror.TruncateBody(body, max)
}
// routeAntigravityTest routes a connection-test request for an Antigravity account.
// APIKey accounts use the native protocol (consistent with the gateway_handler
// routing): model IDs starting with "gemini-" go to the Gemini test, all others
// to the Claude test. OAuth/Upstream accounts go through the CRS relay test.
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error {
	if account.Type != AccountTypeAPIKey {
		return s.testAntigravityAccountConnection(c, account, modelID)
	}

	isGeminiModel := strings.HasPrefix(modelID, "gemini-")
	if isGeminiModel {
		return s.testGeminiAccountConnection(c, account, modelID)
	}
	return s.testClaudeAccountConnection(c, account, modelID)
}
// testAntigravityAccountConnection tests an Antigravity account's connection
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {

View File

@@ -84,6 +84,7 @@ type AdminService interface {
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
ResetAccountQuota(ctx context.Context, id int64) error
}
// CreateUserInput represents input for creating a new user via admin operations.
@@ -137,9 +138,10 @@ type CreateGroupInput struct {
// 无效请求兜底分组 ID仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
MCPXMLInject *bool
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
MCPXMLInject *bool
SimulateClaudeMaxEnabled *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// Sora 存储配额
@@ -173,9 +175,10 @@ type UpdateGroupInput struct {
// 无效请求兜底分组 ID仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
MCPXMLInject *bool
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
MCPXMLInject *bool
SimulateClaudeMaxEnabled *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// Sora 存储配额
@@ -195,6 +198,7 @@ type CreateAccountInput struct {
Concurrency int
Priority int
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
LoadFactor *int
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
@@ -215,6 +219,7 @@ type UpdateAccountInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
LoadFactor *int
Status string
GroupIDs *[]int64
ExpiresAt *int64
@@ -230,6 +235,7 @@ type BulkUpdateAccountsInput struct {
Concurrency *int
Priority *int
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
LoadFactor *int
Status string
Schedulable *bool
GroupIDs *[]int64
@@ -353,6 +359,10 @@ type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
// groupExistenceBatchReader is an optional capability for looking up several
// group IDs in one call.
type groupExistenceBatchReader interface {
// ExistsByIDs takes a list of IDs and returns a map keyed by int64 with bool values.
// NOTE(review): presumably the map is keyed by group ID and the value says
// whether that group exists — confirm against the implementations.
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
type proxyQualityTarget struct {
Target string
URL string
@@ -428,10 +438,6 @@ type userGroupRateBatchReader interface {
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
// groupExistenceBatchReader is an optional capability for looking up several
// group IDs in one call.
type groupExistenceBatchReader interface {
// ExistsByIDs takes a list of IDs and returns a map keyed by int64 with bool values.
// NOTE(review): presumably the map is keyed by group ID and the value says
// whether that group exists — confirm against the implementations.
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
@@ -847,6 +853,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if input.MCPXMLInject != nil {
mcpXMLInject = *input.MCPXMLInject
}
simulateClaudeMaxEnabled := false
if input.SimulateClaudeMaxEnabled != nil {
if platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
}
simulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var accountIDsToCopy []int64
@@ -903,6 +916,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
@@ -1112,6 +1126,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MCPXMLInject != nil {
group.MCPXMLInject = *input.MCPXMLInject
}
if input.SimulateClaudeMaxEnabled != nil {
if group.Platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
}
group.SimulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
}
if group.Platform != PlatformAnthropic {
group.SimulateClaudeMaxEnabled = false
}
// 支持的模型系列(仅 antigravity 平台使用)
if input.SupportedModelScopes != nil {
@@ -1413,6 +1436,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
account.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil && *input.LoadFactor > 0 {
account.LoadFactor = input.LoadFactor
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
@@ -1458,6 +1484,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
// 保留 quota_used防止编辑账号时意外重置配额用量
if oldQuotaUsed, ok := account.Extra["quota_used"]; ok {
input.Extra["quota_used"] = oldQuotaUsed
}
account.Extra = input.Extra
}
if input.ProxyID != nil {
@@ -1483,6 +1513,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
}
account.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil {
if *input.LoadFactor <= 0 {
account.LoadFactor = nil // 0 或负数表示清除
} else {
account.LoadFactor = input.LoadFactor
}
}
if input.Status != "" {
account.Status = input.Status
}
@@ -1616,6 +1653,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.RateMultiplier != nil {
repoUpdates.RateMultiplier = input.RateMultiplier
}
if input.LoadFactor != nil {
repoUpdates.LoadFactor = input.LoadFactor
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
@@ -2439,3 +2479,7 @@ func (e *MixedChannelError) Error() string {
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}
// ResetAccountQuota resets the quota usage of the account identified by id.
// It delegates directly to the account repository's ResetQuotaUsed and returns
// whatever error that call produces.
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
return s.accountRepo.ResetQuotaUsed(ctx, id)
}

View File

@@ -43,6 +43,16 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
return nil
}
// ListByGroup returns the canned result configured for groupID.
// A configured error takes precedence over configured data. When nothing is
// configured for the group, it returns (nil, nil).
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
	if stubErr, hasErr := s.listByGroupErr[groupID]; hasErr {
		return nil, stubErr
	}
	// A missing key yields a nil slice, which is the "nothing configured" result.
	return s.listByGroupData[groupID], nil
}
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
s.getByIDsCalled = true
s.getByIDsIDs = append([]int64{}, ids...)
@@ -63,16 +73,6 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
// ListByGroup returns the canned result configured for groupID.
// A configured error takes precedence over configured data. When nothing is
// configured for the group, it returns (nil, nil).
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
	if stubErr, hasErr := s.listByGroupErr[groupID]; hasErr {
		return nil, stubErr
	}
	// A missing key yields a nil slice, which is the "nothing configured" result.
	return s.listByGroupData[groupID], nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}

View File

@@ -785,3 +785,57 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
// TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic verifies that
// creating a non-anthropic group with simulate_claude_max_enabled=true is
// rejected and nothing is persisted.
func TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
	groupRepo := &groupRepoStubForAdmin{}
	adminSvc := &adminServiceImpl{groupRepo: groupRepo}

	simulate := true
	input := &CreateGroupInput{
		Name:                     "openai-group",
		Platform:                 PlatformOpenAI,
		SimulateClaudeMaxEnabled: &simulate,
	}

	_, err := adminSvc.CreateGroup(context.Background(), input)

	require.Error(t, err)
	require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
	require.Nil(t, groupRepo.created)
}
// TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic verifies that
// enabling simulate_claude_max_enabled on an existing non-anthropic group is
// rejected and no update is written.
func TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
	stored := &Group{
		ID:       1,
		Name:     "openai-group",
		Platform: PlatformOpenAI,
		Status:   StatusActive,
	}
	groupRepo := &groupRepoStubForAdmin{getByID: stored}
	adminSvc := &adminServiceImpl{groupRepo: groupRepo}

	simulate := true
	input := &UpdateGroupInput{SimulateClaudeMaxEnabled: &simulate}

	_, err := adminSvc.UpdateGroup(context.Background(), 1, input)

	require.Error(t, err)
	require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
	require.Nil(t, groupRepo.updated)
}
// TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges verifies
// that switching an anthropic group to another platform turns off
// SimulateClaudeMaxEnabled in the persisted update.
func TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges(t *testing.T) {
	stored := &Group{
		ID:                       1,
		Name:                     "anthropic-group",
		Platform:                 PlatformAnthropic,
		Status:                   StatusActive,
		SimulateClaudeMaxEnabled: true,
	}
	groupRepo := &groupRepoStubForAdmin{getByID: stored}
	adminSvc := &adminServiceImpl{groupRepo: groupRepo}

	input := &UpdateGroupInput{Platform: PlatformOpenAI}
	result, err := adminSvc.UpdateGroup(context.Background(), 1, input)

	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, groupRepo.updated)
	require.False(t, groupRepo.updated.SimulateClaudeMaxEnabled)
}

View File

@@ -1599,7 +1599,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var clientDisconnect bool
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
return nil, err
@@ -1609,7 +1609,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
clientDisconnect = streamRes.clientDisconnect
} else {
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
return nil, err
@@ -1618,6 +1618,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
firstTokenMs = streamRes.firstTokenMs
}
// Claude Max cache billing: 同步 ForwardResult.Usage 与客户端响应体一致
applyClaudeMaxCacheBillingPolicyToUsage(usage, parsedRequestFromGinContext(c), claudeMaxGroupFromGinContext(c), originalModel, account.ID)
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -3415,7 +3418,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
// 用于处理客户端非流式请求但上游只支持流式的情况
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
@@ -3573,6 +3576,9 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
// Claude Max cache billing simulation (non-streaming)
claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID)
c.Data(http.StatusOK, "application/json", claudeResp)
// 转换为 service.ClaudeUsage
@@ -3587,7 +3593,7 @@ returnResponse:
}
// handleClaudeStreamingResponse 处理 Claude 流式响应Gemini SSE → Claude SSE 转换)
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -3600,6 +3606,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
processor := antigravity.NewStreamingProcessor(originalModel)
setupClaudeMaxStreamingHook(c, processor, originalModel, accountID)
var firstTokenMs *int
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)

View File

@@ -710,7 +710,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
_ = pr.Close()
require.NoError(t, err)
@@ -787,7 +787,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
_ = pr.Close()
require.NoError(t, err)
@@ -990,7 +990,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
_ = pr.Close()
require.NoError(t, err)
@@ -1014,7 +1014,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
require.NoError(t, err)
require.NotNil(t, result)

View File

@@ -59,9 +59,10 @@ type APIKeyAuthGroupSnapshot struct {
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`

View File

@@ -244,6 +244,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
MCPXMLInject: apiKey.Group.MCPXMLInject,
SimulateClaudeMaxEnabled: apiKey.Group.SimulateClaudeMaxEnabled,
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
}
}
@@ -301,6 +302,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
MCPXMLInject: snapshot.Group.MCPXMLInject,
SimulateClaudeMaxEnabled: snapshot.Group.SimulateClaudeMaxEnabled,
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
}
}

View File

@@ -0,0 +1,450 @@
package service
import (
"encoding/json"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
)
// claudeMaxCacheBillingOutcome is the result returned by
// applyClaudeMaxCacheBillingPolicyToUsage.
type claudeMaxCacheBillingOutcome struct {
// Simulated is true when the usage was rewritten by the Claude Max projection.
Simulated bool
}
// applyClaudeMaxCacheBillingPolicyToUsage rewrites usage in place according to
// the Claude Max cache-billing simulation, when that simulation applies.
//
// Nothing is changed when usage is nil, when the billing rules do not apply to
// the group/model/request, when the upstream already reported cache-creation
// tokens, or when the usage does not qualify for simulation. On a successful
// projection the change is logged and Simulated is set on the returned outcome.
func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome {
	outcome := claudeMaxCacheBillingOutcome{}
	if usage == nil {
		return outcome
	}
	if !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) {
		return outcome
	}
	// Upstream already returned cache creation usage; keep original usage.
	if hasCacheCreationTokens(*usage) {
		return outcome
	}
	if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) {
		return outcome
	}

	inputTokensBefore := usage.InputTokens
	outcome.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed)
	if !outcome.Simulated {
		return outcome
	}

	// The model name is only needed for the log line; fall back to the model of
	// the parsed request when the explicit one is blank.
	logModel := strings.TrimSpace(model)
	if logModel == "" && parsed != nil {
		logModel = strings.TrimSpace(parsed.Model)
	}
	logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
		logModel,
		accountID,
		inputTokensBefore,
		usage.InputTokens,
		usage.CacheCreation1hTokens,
	)
	return outcome
}
// isClaudeFamilyModel reports whether model, after normalization via
// claude.NormalizeModelID, trimming and lower-casing, contains "claude-".
func isClaudeFamilyModel(model string) bool {
	canonical := claude.NormalizeModelID(model)
	canonical = strings.ToLower(strings.TrimSpace(canonical))
	return canonical != "" && strings.Contains(canonical, "claude-")
}
// shouldApplyClaudeMaxBillingRules reports whether the Claude Max billing rules
// apply to a usage-record input. It returns false when the input, its result,
// its API key or the key's group is missing; otherwise it defers to
// shouldApplyClaudeMaxBillingRulesForUsage.
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
	if input == nil {
		return false
	}
	result, apiKey := input.Result, input.APIKey
	if result == nil || apiKey == nil || apiKey.Group == nil {
		return false
	}
	return shouldApplyClaudeMaxBillingRulesForUsage(apiKey.Group, result.Model, input.ParsedRequest)
}
// shouldApplyClaudeMaxBillingRulesForUsage reports whether the Claude Max billing
// rules apply for the given group, model and parsed request.
//
// The group must be non-nil, have SimulateClaudeMaxEnabled set and be on the
// anthropic platform. The model (or, when it is empty, the parsed request's
// model) must be a Claude-family model.
func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool {
	if group == nil || !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
		return false
	}

	candidate := model
	if candidate == "" && parsed != nil {
		candidate = parsed.Model
	}
	return isClaudeFamilyModel(candidate)
}
// hasCacheCreationTokens reports whether any cache-creation counter (the
// aggregate, the 5m bucket, or the 1h bucket) is positive.
func hasCacheCreationTokens(usage ClaudeUsage) bool {
	switch {
	case usage.CacheCreationInputTokens > 0:
		return true
	case usage.CacheCreation5mTokens > 0:
		return true
	case usage.CacheCreation1hTokens > 0:
		return true
	}
	return false
}
// shouldSimulateClaudeMaxUsage reports whether the usage in a recording input
// should be projected: the billing rules must apply and the usage itself must
// qualify for simulation.
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
	if input == nil || input.Result == nil || !shouldApplyClaudeMaxBillingRules(input) {
		return false
	}
	return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest)
}
// shouldSimulateClaudeMaxUsageForUsage reports whether a usage record qualifies
// for simulation: it has input tokens, no cache-creation tokens were reported,
// and the request carries cache signals.
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
	return usage.InputTokens > 0 &&
		!hasCacheCreationTokens(usage) &&
		hasClaudeCacheSignals(parsed)
}
// safelyProjectUsageToClaudeMax1H wraps projectUsageToClaudeMax1H with panic
// recovery. If the projection panics, the panic is logged and changed is
// reported as false instead of propagating the panic to the caller.
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
defer func() {
if r := recover(); r != nil {
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
// The named result lets the deferred handler override the return value.
changed = false
}
}()
return projectUsageToClaudeMax1H(usage, parsed)
}
// projectUsageToClaudeMax1H redistributes the usage's window tokens (input +
// 5m cache creation + 1h cache creation) into a small input share and a 1h
// cache-creation share, conserving the total.
//
// The input share comes from computeClaudeMaxProjectedInputTokens and is
// clamped to [1, total-1]. The 5m bucket is zeroed and the aggregate
// cache-creation counter mirrors the 1h bucket. It returns false when usage is
// nil, the window holds at most one token, or the usage already has the
// projected shape.
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
	if usage == nil {
		return false
	}

	windowTotal := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
	if windowTotal <= 1 {
		return false
	}

	projectedInput := computeClaudeMaxProjectedInputTokens(windowTotal, parsed)
	switch {
	case projectedInput <= 0:
		projectedInput = 1
	case projectedInput >= windowTotal:
		projectedInput = windowTotal - 1
	}
	projected1h := windowTotal - projectedInput

	alreadyProjected := usage.InputTokens == projectedInput &&
		usage.CacheCreation5mTokens == 0 &&
		usage.CacheCreation1hTokens == projected1h &&
		usage.CacheCreationInputTokens == projected1h
	if alreadyProjected {
		return false
	}

	usage.InputTokens, usage.CacheCreation5mTokens = projectedInput, 0
	usage.CacheCreation1hTokens, usage.CacheCreationInputTokens = projected1h, projected1h
	return true
}
// claudeCacheProjection summarizes where cache breakpoints sit in a request,
// measured in estimated tokens. It is produced by analyzeClaudeCacheProjection.
type claudeCacheProjection struct {
// HasBreakpoint is true when an ephemeral cache_control marker was found.
HasBreakpoint bool
// BreakpointCount is the number of cache_control markers counted.
BreakpointCount int
// TotalEstimatedTokens is the estimated token count of system plus messages (at least 1).
TotalEstimatedTokens int
// TailEstimatedTokens is the estimated token count after the last breakpoint.
TailEstimatedTokens int
}
// computeClaudeMaxProjectedInputTokens scales the estimated tail/total ratio of
// the request onto totalWindowTokens and returns the share that stays billed as
// plain input.
//
// When the window holds at most one token, or the request yields no usable
// projection, totalWindowTokens is returned unchanged. Otherwise the result is
// rounded to nearest and clamped to [1, totalWindowTokens-1].
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
	if totalWindowTokens <= 1 {
		return totalWindowTokens
	}

	proj := analyzeClaudeCacheProjection(parsed)
	usable := proj.HasBreakpoint && proj.TotalEstimatedTokens > 0 && proj.TailEstimatedTokens > 0
	if !usable {
		return totalWindowTokens
	}

	window := int64(totalWindowTokens)
	estTotal := int64(proj.TotalEstimatedTokens)
	estTail := int64(proj.TailEstimatedTokens)
	if estTail > estTotal {
		estTail = estTotal
	}

	// Adding estTotal/2 before dividing rounds to the nearest integer.
	projected := (window*estTail + estTotal/2) / estTotal
	switch {
	case projected <= 0:
		projected = 1
	case projected >= window:
		projected = window - 1
	}
	return int(projected)
}
// hasClaudeCacheSignals reports whether the request asks for caching, either
// through a top-level ephemeral cache_control or through at least one explicit
// block-level breakpoint.
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
	return parsed != nil &&
		(hasTopLevelEphemeralCacheControl(parsed) || countExplicitCacheBreakpoints(parsed) > 0)
}
// hasTopLevelEphemeralCacheControl reports whether the raw request body has a
// top-level cache_control.type equal to "ephemeral" (case-insensitive,
// surrounding whitespace ignored).
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
	if parsed == nil {
		return false
	}
	if len(parsed.Body) == 0 {
		return false
	}
	topLevelType := gjson.GetBytes(parsed.Body, "cache_control.type").String()
	return strings.EqualFold(strings.TrimSpace(topLevelType), "ephemeral")
}
// analyzeClaudeCacheProjection walks the system prompt and the messages of a
// parsed request, estimating tokens and recording where the last ephemeral
// cache breakpoint falls.
//
// The result carries the estimated total, the estimated tail after the last
// breakpoint, and the number of breakpoints seen. When no block-level
// breakpoint exists but the body has a top-level ephemeral cache_control, the
// last user message is used as the tail. A nil request yields a zero projection.
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
var projection claudeCacheProjection
if parsed == nil {
return projection
}
// total is the running token estimate; lastBreakpointAt is the value of total
// right after the most recent breakpoint block (-1 while none was seen).
total := 0
lastBreakpointAt := -1
switch system := parsed.System.(type) {
case string:
// A plain string system prompt cannot carry a breakpoint.
total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
case []any:
for _, raw := range system {
block, ok := raw.(map[string]any)
if !ok {
total += claudeMaxUnknownContentTokens
continue
}
total += estimateClaudeBlockTokens(block)
if hasEphemeralCacheControl(block) {
lastBreakpointAt = total
projection.BreakpointCount++
projection.HasBreakpoint = true
}
}
}
for _, rawMsg := range parsed.Messages {
total += claudeMaxMessageOverheadTokens
msg, ok := rawMsg.(map[string]any)
if !ok {
total += claudeMaxUnknownContentTokens
continue
}
content, exists := msg["content"]
if !exists {
continue
}
msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
total += msgTokens
if msgBreakCount > 0 {
// msgLastBreak is relative to the start of this message's content, so
// rebase it onto the running total.
lastBreakpointAt = total - msgTokens + msgLastBreak
projection.BreakpointCount += msgBreakCount
projection.HasBreakpoint = true
}
}
// Keep the total strictly positive so callers can divide by it.
if total <= 0 {
total = 1
}
projection.TotalEstimatedTokens = total
if projection.HasBreakpoint && lastBreakpointAt >= 0 {
tail := total - lastBreakpointAt
if tail <= 0 {
tail = 1
}
projection.TailEstimatedTokens = tail
return projection
}
// No block-level breakpoint: fall back to the top-level cache_control marker
// and treat the last user message as the tail.
if hasTopLevelEphemeralCacheControl(parsed) {
tail := estimateLastUserMessageTokens(parsed)
if tail <= 0 {
tail = 1
}
projection.HasBreakpoint = true
projection.BreakpointCount = 1
projection.TailEstimatedTokens = tail
}
return projection
}
// countExplicitCacheBreakpoints counts the blocks that carry an ephemeral
// cache_control marker, looking at list-form system prompts and list-form
// message contents. Entries of any other shape are ignored.
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
	if parsed == nil {
		return 0
	}

	count := 0
	// tally adds every ephemeral-marked block in blocks to count.
	tally := func(blocks []any) {
		for _, item := range blocks {
			block, isMap := item.(map[string]any)
			if isMap && hasEphemeralCacheControl(block) {
				count++
			}
		}
	}

	if systemBlocks, isList := parsed.System.([]any); isList {
		tally(systemBlocks)
	}
	for _, item := range parsed.Messages {
		message, isMap := item.(map[string]any)
		if !isMap {
			continue
		}
		if blocks, isList := message["content"].([]any); isList {
			tally(blocks)
		}
	}
	return count
}
// hasEphemeralCacheControl reports whether a content block has a cache_control
// object whose "type" is "ephemeral" (case-insensitive, surrounding whitespace
// ignored). Both map[string]any and map[string]string forms are accepted; a
// missing, nil, or differently shaped cache_control yields false.
func hasEphemeralCacheControl(block map[string]any) bool {
	var cacheType string
	// Indexing a nil map is safe and yields nil, which lands in the default case.
	switch cc := block["cache_control"].(type) {
	case map[string]any:
		cacheType, _ = cc["type"].(string)
	case map[string]string:
		cacheType = cc["type"]
	default:
		return false
	}
	return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
}
// estimateClaudeContentTokens estimates the tokens of a message "content"
// value.
//
// For list content it sums the block estimates and returns the running total
// right after the last ephemeral-marked block (lastBreakAt, -1 if none) along
// with the number of such blocks. String content is estimated as text and any
// other shape as structured data; neither can carry breakpoints.
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
	blocks, isList := content.([]any)
	if !isList {
		if text, isText := content.(string); isText {
			return estimateClaudeTextTokens(text), -1, 0
		}
		return estimateStructuredTokens(content), -1, 0
	}

	lastBreakAt = -1
	for _, item := range blocks {
		block, isMap := item.(map[string]any)
		if !isMap {
			tokens += claudeMaxUnknownContentTokens
			continue
		}
		tokens += estimateClaudeBlockTokens(block)
		if hasEphemeralCacheControl(block) {
			lastBreakAt = tokens
			breakpointCount++
		}
	}
	return tokens, lastBreakAt, breakpointCount
}
// estimateClaudeBlockTokens estimates the tokens of a single content block:
// the per-block overhead plus a payload that depends on the block type.
//
// text blocks count their text, tool_result blocks their nested content,
// tool_use blocks their name and input, and other types their text or, failing
// that, their content. A nil block, or a block with no measurable payload, is
// charged the unknown-content fallback.
func estimateClaudeBlockTokens(block map[string]any) int {
	if block == nil {
		return claudeMaxUnknownContentTokens
	}

	payload := 0
	kind, _ := block["type"].(string)
	switch kind {
	case "tool_result":
		if content, present := block["content"]; present {
			payload, _, _ = estimateClaudeContentTokens(content)
		}
	case "tool_use":
		if name, isText := block["name"].(string); isText {
			payload += estimateClaudeTextTokens(name)
		}
		if input, present := block["input"]; present {
			payload += estimateStructuredTokens(input)
		}
	case "text":
		if text, isText := block["text"].(string); isText {
			payload = estimateClaudeTextTokens(text)
		}
	default:
		if text, isText := block["text"].(string); isText {
			payload = estimateClaudeTextTokens(text)
		} else if content, present := block["content"]; present {
			payload, _, _ = estimateClaudeContentTokens(content)
		}
	}

	// Nothing measurable was found: charge the unknown-content fallback.
	if payload <= 0 {
		payload += claudeMaxUnknownContentTokens
	}
	return claudeMaxBlockOverheadTokens + payload
}
// estimateLastUserMessageTokens returns the estimated tokens (message overhead
// plus content) of the last message whose role is "user". It returns 0 when
// the request is nil or has no user message.
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
	if parsed == nil {
		return 0
	}
	for idx := len(parsed.Messages) - 1; idx >= 0; idx-- {
		message, isMap := parsed.Messages[idx].(map[string]any)
		if !isMap {
			continue
		}
		if role, _ := message["role"].(string); strings.EqualFold(strings.TrimSpace(role), "user") {
			contentTokens, _, _ := estimateClaudeContentTokens(message["content"])
			return claudeMaxMessageOverheadTokens + contentTokens
		}
	}
	return 0
}
func estimateStructuredTokens(v any) int {
if v == nil {
return 0
}
raw, err := json.Marshal(v)
if err != nil {
return claudeMaxUnknownContentTokens
}
return estimateClaudeTextTokens(string(raw))
}
// estimateClaudeTextTokens estimates the tokens of text, preferring the
// third-party tokenizer and falling back to the character heuristic when the
// tokenizer gives no answer.
func estimateClaudeTextTokens(text string) int {
	tokens, ok := estimateTokensByThirdPartyTokenizer(text)
	if !ok {
		tokens = estimateClaudeTextTokensHeuristic(text)
	}
	return tokens
}
// estimateClaudeTextTokensHeuristic estimates tokens without a tokenizer.
//
// Whitespace runs are collapsed to single spaces, then every non-ASCII rune
// counts as one token and every four ASCII characters (rounded up) as one
// token. The estimate is never lower than the word count. Blank text yields 0.
func estimateClaudeTextTokensHeuristic(text string) int {
	words := strings.Fields(text)
	if len(words) == 0 {
		return 0
	}
	normalized := strings.Join(words, " ")

	var asciiCount, wideCount int
	for _, r := range normalized {
		if r > 127 {
			wideCount++
		} else {
			asciiCount++
		}
	}

	estimate := wideCount + (asciiCount+3)/4
	if len(words) > estimate {
		estimate = len(words)
	}
	if estimate <= 0 {
		return 1
	}
	return estimate
}

View File

@@ -0,0 +1,156 @@
package service
import (
"strings"
"testing"
)
// TestProjectUsageToClaudeMax1H_Conservation checks that projecting a usage
// with one large cached block and a short tail keeps the token total, zeroes
// the 5m bucket, leaves only a small input share, and mirrors the 1h bucket in
// the aggregate cache-creation counter.
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
usage := &ClaudeUsage{
InputTokens: 1200,
CacheCreationInputTokens: 0,
CacheCreation5mTokens: 0,
CacheCreation1hTokens: 0,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
// Large block marked as the cache breakpoint.
map[string]any{
"type": "text",
"text": strings.Repeat("cached context ", 200),
"cache_control": map[string]any{"type": "ephemeral"},
},
// Short tail after the breakpoint.
map[string]any{
"type": "text",
"text": "summarize quickly",
},
},
},
},
}
changed := projectUsageToClaudeMax1H(usage, parsed)
if !changed {
t.Fatalf("expected usage to be projected")
}
// The projection must only move tokens between buckets.
total := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
if total != 1200 {
t.Fatalf("total tokens changed: got=%d want=%d", total, 1200)
}
if usage.CacheCreation5mTokens != 0 {
t.Fatalf("cache_creation_5m should be 0, got=%d", usage.CacheCreation5mTokens)
}
if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
}
if usage.InputTokens > 100 {
t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
}
if usage.CacheCreation1hTokens <= 0 {
t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
}
if usage.CacheCreationInputTokens != usage.CacheCreation1hTokens {
t.Fatalf("cache_creation_input_tokens mismatch: got=%d want=%d", usage.CacheCreationInputTokens, usage.CacheCreation1hTokens)
}
}
// TestComputeClaudeMaxProjectedInputTokens_Deterministic checks that two calls
// with the same window size and request return the same projected input count.
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
Model: "claude-opus-4-5",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "build context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "what is failing now",
},
},
},
},
}
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
if got1 != got2 {
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
}
}
// TestShouldSimulateClaudeMaxUsage covers three cases for an enabled Anthropic
// group: a request with a cache breakpoint is simulated, a request without
// cache signals is not, and a usage that already reports cache creation is not.
func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
group := &Group{
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: true,
}
input := &RecordUsageInput{
Result: &ForwardResult{
Model: "claude-sonnet-4-5",
Usage: ClaudeUsage{
InputTokens: 3000,
CacheCreationInputTokens: 0,
CacheCreation5mTokens: 0,
CacheCreation1hTokens: 0,
},
},
ParsedRequest: &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "cached",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "tail",
},
},
},
},
},
APIKey: &APIKey{Group: group},
}
// Case 1: explicit breakpoint and no upstream cache creation.
if !shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=true for claude group with cache signal")
}
// Case 2: plain string content carries no cache signal.
input.ParsedRequest = &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "no cache signal"},
},
}
if shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=false when request has no cache signal")
}
// Case 3: cache signal present, but upstream already reported cache creation.
input.ParsedRequest = &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "cached",
"cache_control": map[string]any{"type": "ephemeral"},
},
},
},
},
}
input.Result.Usage.CacheCreationInputTokens = 100
if shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=false when cache creation already exists")
}
}

View File

@@ -0,0 +1,41 @@
package service
import (
"sync"
tiktoken "github.com/pkoukk/tiktoken-go"
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
)
// Lazily initialized tokenizer shared by all token estimates.
var (
// claudeTokenizerOnce guards the one-time initialization in getClaudeTokenizer.
claudeTokenizerOnce sync.Once
// claudeTokenizer stays nil when no encoding could be loaded.
claudeTokenizer *tiktoken.Tiktoken
)
// getClaudeTokenizer returns the shared tokenizer, initializing it on first
// use. It installs the offline BPE loader, then tries the o200k_base encoding
// and falls back to cl100k_base. The result is nil when neither encoding loads;
// the attempt is made only once.
func getClaudeTokenizer() *tiktoken.Tiktoken {
	claudeTokenizerOnce.Do(func() {
		// Use offline loader to avoid runtime dictionary download.
		tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())

		// Encodings in order of preference; the first that loads wins.
		for _, encodingName := range []string{tiktoken.MODEL_O200K_BASE, tiktoken.MODEL_CL100K_BASE} {
			if enc, err := tiktoken.GetEncoding(encodingName); err == nil {
				claudeTokenizer = enc
				return
			}
		}
	})
	return claudeTokenizer
}
// estimateTokensByThirdPartyTokenizer counts the tokens of text with the shared
// tokenizer. The boolean is false when the tokenizer is unavailable or the
// encoding produced no tokens, in which case the count is 0.
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
	enc := getClaudeTokenizer()
	if enc == nil {
		return 0, false
	}
	if count := len(enc.EncodeOrdinary(text)); count > 0 {
		return count, true
	}
	return 0, false
}

View File

@@ -331,8 +331,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
// Uses a detached context with timeout to prevent HTTP request cancellation from
// causing the entire batch to fail (which would show all concurrency as 0).
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
@@ -344,5 +345,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
}
return result, nil
}
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
// Use a detached context so that a cancelled HTTP request doesn't cause
// the Redis pipeline to fail and return all-zero concurrency counts.
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
}

View File

@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
v, exists := c.Get(OpsSkipPassthroughKey)
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
boolVal, ok := v.(bool)
assert.True(t, ok, "value should be bool")
assert.True(t, ok, "value should be a bool")
assert.True(t, boolVal)
}

View File

@@ -0,0 +1,196 @@
package service
import (
"context"
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/tidwall/sjson"
)
// claudeMaxResponseRewriteContext carries the request data needed to rewrite
// response usage for the Claude Max simulation.
type claudeMaxResponseRewriteContext struct {
// Parsed is the parsed request the response belongs to.
Parsed *ParsedRequest
// Group is the API key's group as read from the gin context; may be nil.
Group *Group
}
// claudeMaxResponseRewriteContextKeyType is an unexported key type, so the
// context key cannot collide with keys defined by other packages.
type claudeMaxResponseRewriteContextKeyType struct{}
// claudeMaxResponseRewriteContextKey is the context key under which a
// claudeMaxResponseRewriteContext value is stored.
var claudeMaxResponseRewriteContextKey = claudeMaxResponseRewriteContextKeyType{}
// withClaudeMaxResponseRewriteContext returns a context that carries the parsed
// request and the group taken from the gin context. A nil ctx is replaced by
// context.Background().
func withClaudeMaxResponseRewriteContext(ctx context.Context, c *gin.Context, parsed *ParsedRequest) context.Context {
	if ctx == nil {
		ctx = context.Background()
	}
	return context.WithValue(ctx, claudeMaxResponseRewriteContextKey, claudeMaxResponseRewriteContext{
		Parsed: parsed,
		Group:  claudeMaxGroupFromGinContext(c),
	})
}
// claudeMaxResponseRewriteContextFromContext extracts the rewrite context
// stored by withClaudeMaxResponseRewriteContext. It returns the zero value when
// ctx is nil or holds no such value.
func claudeMaxResponseRewriteContextFromContext(ctx context.Context) claudeMaxResponseRewriteContext {
	var rewrite claudeMaxResponseRewriteContext
	if ctx != nil {
		if stored, ok := ctx.Value(claudeMaxResponseRewriteContextKey).(claudeMaxResponseRewriteContext); ok {
			rewrite = stored
		}
	}
	return rewrite
}
// claudeMaxGroupFromGinContext returns the group of the *APIKey stored under
// "api_key" in the gin context, or nil when the context, the key, or a usable
// API key is missing.
func claudeMaxGroupFromGinContext(c *gin.Context) *Group {
	if c == nil {
		return nil
	}
	stored, found := c.Get("api_key")
	if !found {
		return nil
	}
	if key, isKey := stored.(*APIKey); isKey && key != nil {
		return key.Group
	}
	return nil
}
// parsedRequestFromGinContext returns the *ParsedRequest stored under
// "parsed_request" in the gin context, or nil when the context or the value is
// missing or has another type.
func parsedRequestFromGinContext(c *gin.Context) *ParsedRequest {
	if c == nil {
		return nil
	}
	if stored, found := c.Get("parsed_request"); found {
		if parsed, isParsed := stored.(*ParsedRequest); isParsed {
			return parsed
		}
	}
	return nil
}
// applyClaudeMaxSimulationToUsage applies the Claude Max cache billing policy
// to usage, taking the parsed request and group from the rewrite context held
// in ctx. A nil usage yields a zero outcome.
func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome {
	if usage == nil {
		return claudeMaxCacheBillingOutcome{}
	}
	rc := claudeMaxResponseRewriteContextFromContext(ctx)
	return applyClaudeMaxCacheBillingPolicyToUsage(usage, rc.Parsed, rc.Group, model, accountID)
}
// applyClaudeMaxSimulationToUsageJSONMap applies the simulation to a decoded
// JSON usage object. The map is rewritten in place only when the usage was
// actually simulated. A nil map yields a zero outcome.
func applyClaudeMaxSimulationToUsageJSONMap(ctx context.Context, usageObj map[string]any, model string, accountID int64) claudeMaxCacheBillingOutcome {
	if usageObj == nil {
		return claudeMaxCacheBillingOutcome{}
	}
	parsedUsage := claudeUsageFromJSONMap(usageObj)
	outcome := applyClaudeMaxSimulationToUsage(ctx, &parsedUsage, model, accountID)
	if outcome.Simulated {
		rewriteClaudeUsageJSONMap(usageObj, parsedUsage)
	}
	return outcome
}
// rewriteClaudeUsageJSONBytes writes the input and cache-creation counters of
// usage into the "usage" object of a JSON response body. The rewrite is
// all-or-nothing: if any field cannot be set, the original body is returned.
func rewriteClaudeUsageJSONBytes(body []byte, usage ClaudeUsage) []byte {
	fields := [...]struct {
		path  string
		value int
	}{
		{"usage.input_tokens", usage.InputTokens},
		{"usage.cache_creation_input_tokens", usage.CacheCreationInputTokens},
		{"usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation5mTokens},
		{"usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation1hTokens},
	}

	rewritten := body
	for _, field := range fields {
		next, err := sjson.SetBytes(rewritten, field.path, field.value)
		if err != nil {
			return body
		}
		rewritten = next
	}
	return rewritten
}
// claudeUsageFromJSONMap builds a ClaudeUsage from a decoded JSON usage object.
// Missing or non-numeric fields become 0; the 5m/1h buckets are read from the
// nested "cache_creation" object when it is present.
func claudeUsageFromJSONMap(usageObj map[string]any) ClaudeUsage {
	if usageObj == nil {
		return ClaudeUsage{}
	}
	usage := ClaudeUsage{
		InputTokens:              usageIntFromAny(usageObj["input_tokens"]),
		OutputTokens:             usageIntFromAny(usageObj["output_tokens"]),
		CacheCreationInputTokens: usageIntFromAny(usageObj["cache_creation_input_tokens"]),
		CacheReadInputTokens:     usageIntFromAny(usageObj["cache_read_input_tokens"]),
	}
	if breakdown, ok := usageObj["cache_creation"].(map[string]any); ok {
		usage.CacheCreation5mTokens = usageIntFromAny(breakdown["ephemeral_5m_input_tokens"])
		usage.CacheCreation1hTokens = usageIntFromAny(breakdown["ephemeral_1h_input_tokens"])
	}
	return usage
}
// rewriteClaudeUsageJSONMap writes the input and cache-creation counters of
// usage into a decoded JSON usage object, creating the nested "cache_creation"
// object when it is absent or has another shape. A nil map is left alone.
func rewriteClaudeUsageJSONMap(usageObj map[string]any, usage ClaudeUsage) {
	if usageObj == nil {
		return
	}

	breakdown, ok := usageObj["cache_creation"].(map[string]any)
	if !ok || breakdown == nil {
		breakdown = make(map[string]any, 2)
		usageObj["cache_creation"] = breakdown
	}
	breakdown["ephemeral_5m_input_tokens"] = usage.CacheCreation5mTokens
	breakdown["ephemeral_1h_input_tokens"] = usage.CacheCreation1hTokens

	usageObj["input_tokens"] = usage.InputTokens
	usageObj["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
}
// usageIntFromAny converts a loosely typed JSON numeric value to int.
//
// It accepts int, int64, float64 (truncated toward zero) and json.Number.
// A json.Number is parsed as an integer first; decimal or exponent forms such
// as "120.0" or "1e3", which Int64 rejects, are parsed as a float and
// truncated like the float64 case. Any other type, and any value that cannot
// be parsed or does not fit in an int64, yields 0.
func usageIntFromAny(v any) int {
	switch value := v.(type) {
	case int:
		return value
	case int64:
		return int(value)
	case float64:
		return int(value)
	case json.Number:
		if n, err := value.Int64(); err == nil {
			return int(n)
		}
		// Int64 fails on "120.0"/"1e3"; fall back to float parsing. The range
		// guard avoids converting a float that does not fit in an int64.
		if f, err := value.Float64(); err == nil && f > -(1<<63) && f < (1<<63) {
			return int(f)
		}
	}
	return 0
}
// setupClaudeMaxStreamingHook installs an SSE usage-rewrite hook on the
// Antigravity streaming processor.
//
// The hook is installed only when the Claude Max billing rules apply to the
// group and model taken from the gin context. Each usage map seen by the
// processor is then run through the billing policy and rewritten in place when
// the usage was simulated. A nil processor is ignored.
func setupClaudeMaxStreamingHook(c *gin.Context, processor *antigravity.StreamingProcessor, originalModel string, accountID int64) {
	// Without a processor there is nothing to hook; avoid a nil dereference.
	if processor == nil {
		return
	}
	group := claudeMaxGroupFromGinContext(c)
	parsed := parsedRequestFromGinContext(c)
	if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
		return
	}
	processor.SetUsageMapHook(func(usageMap map[string]any) {
		svcUsage := claudeUsageFromJSONMap(usageMap)
		outcome := applyClaudeMaxCacheBillingPolicyToUsage(&svcUsage, parsed, group, originalModel, accountID)
		if outcome.Simulated {
			rewriteClaudeUsageJSONMap(usageMap, svcUsage)
		}
	})
}
// applyClaudeMaxNonStreamingRewrite rewrites the usage inside a non-streaming
// Antigravity response body.
//
// The body is returned unchanged when agUsage is nil, when the Claude Max
// billing rules do not apply to the group and model taken from the gin
// context, or when the policy does not simulate the usage. Otherwise the
// simulated counters are written into the body's "usage" object.
func applyClaudeMaxNonStreamingRewrite(c *gin.Context, claudeResp []byte, agUsage *antigravity.ClaudeUsage, originalModel string, accountID int64) []byte {
	// No upstream usage means there is nothing to project; avoid a nil dereference.
	if agUsage == nil {
		return claudeResp
	}
	group := claudeMaxGroupFromGinContext(c)
	parsed := parsedRequestFromGinContext(c)
	if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
		return claudeResp
	}
	svcUsage := &ClaudeUsage{
		InputTokens:              agUsage.InputTokens,
		OutputTokens:             agUsage.OutputTokens,
		CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
		CacheReadInputTokens:     agUsage.CacheReadInputTokens,
	}
	outcome := applyClaudeMaxCacheBillingPolicyToUsage(svcUsage, parsed, group, originalModel, accountID)
	if outcome.Simulated {
		return rewriteClaudeUsageJSONBytes(claudeResp, *svcUsage)
	}
	return claudeResp
}

View File

@@ -187,6 +187,14 @@ func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64
return 0, nil
}
// IncrementQuotaUsed is a no-op in this mock and always returns nil.
func (m *mockAccountRepoForPlatform) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
// ResetQuotaUsed is a no-op in this mock and always returns nil.
func (m *mockAccountRepoForPlatform) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)

View File

@@ -0,0 +1,199 @@
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// usageLogRepoRecordUsageStub is a UsageLogRepository test double that records
// the last usage log passed to Create and returns canned results.
type usageLogRepoRecordUsageStub struct {
// Embedded interface: only Create is overridden by this stub.
UsageLogRepository
// last is a copy of the most recent log passed to Create.
last *UsageLog
// inserted and err are returned verbatim by Create.
inserted bool
err error
}
// Create stores a copy of log in s.last and returns the stub's configured
// inserted flag and error.
func (s *usageLogRepoRecordUsageStub) Create(_ context.Context, log *UsageLog) (bool, error) {
// Copy so later mutations by the caller do not affect what the test inspects.
copied := *log
s.last = &copied
return s.inserted, s.err
}
// newGatewayServiceForRecordUsageTest builds a GatewayService for RecordUsage
// tests, wired with the given usage-log repository, a billing service built
// from an empty config, simple run mode, and an empty deferred service.
func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayService {
return &GatewayService{
usageLogRepo: repo,
billingService: NewBillingService(&config.Config{}, nil),
cfg: &config.Config{RunMode: config.RunModeSimple},
deferredService: &DeferredService{},
}
}
// TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride
// records a usage with only input tokens for a simulation-enabled group and
// expects the stored log to split them into input and 1h cache creation,
// without the account's cache TTL override being applied.
func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo)
groupID := int64(11)
input := &RecordUsageInput{
Result: &ForwardResult{
RequestID: "req-sim-1",
Model: "claude-sonnet-4",
Duration: time.Second,
Usage: ClaudeUsage{
InputTokens: 160,
},
},
ParsedRequest: &ParsedRequest{
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context for prior turns",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "please summarize the logs and provide root cause analysis",
},
},
},
},
},
APIKey: &APIKey{
ID: 1,
GroupID: &groupID,
Group: &Group{
ID: groupID,
Platform: PlatformAnthropic,
RateMultiplier: 1,
SimulateClaudeMaxEnabled: true,
},
},
User: &User{ID: 2},
Account: &Account{
ID: 3,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
// The account asks for a 5m TTL override, which the simulation must skip.
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
},
}
err := svc.RecordUsage(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, repo.last)
log := repo.last
// The expected 80/80 split depends on the token estimates of the fixture texts.
require.Equal(t, 80, log.InputTokens)
require.Equal(t, 80, log.CacheCreationTokens)
require.Equal(t, 0, log.CacheCreation5mTokens)
require.Equal(t, 80, log.CacheCreation1hTokens)
require.False(t, log.CacheTTLOverridden, "simulate outcome should skip account ttl override")
}
// TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride records a usage
// with 1h cache creation for a group with the simulation disabled and expects
// the account's 5m TTL override to move those tokens into the 5m bucket.
func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo)
groupID := int64(12)
input := &RecordUsageInput{
Result: &ForwardResult{
RequestID: "req-sim-2",
Model: "claude-sonnet-4",
Duration: time.Second,
Usage: ClaudeUsage{
InputTokens: 40,
CacheCreationInputTokens: 120,
CacheCreation1hTokens: 120,
},
},
APIKey: &APIKey{
ID: 2,
GroupID: &groupID,
Group: &Group{
ID: groupID,
Platform: PlatformAnthropic,
RateMultiplier: 1,
SimulateClaudeMaxEnabled: false,
},
},
User: &User{ID: 3},
Account: &Account{
ID: 4,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
},
}
err := svc.RecordUsage(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, repo.last)
log := repo.last
require.Equal(t, 120, log.CacheCreationTokens)
require.Equal(t, 120, log.CacheCreation5mTokens)
require.Equal(t, 0, log.CacheCreation1hTokens)
require.True(t, log.CacheTTLOverridden)
}
// TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation
// records a usage that already reports 5m cache creation for a
// simulation-enabled group and expects the stored log to keep the upstream
// numbers, with neither the simulation nor the TTL override applied.
func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo)
groupID := int64(13)
input := &RecordUsageInput{
Result: &ForwardResult{
RequestID: "req-sim-3",
Model: "claude-sonnet-4",
Duration: time.Second,
Usage: ClaudeUsage{
InputTokens: 20,
CacheCreationInputTokens: 120,
CacheCreation5mTokens: 120,
},
},
APIKey: &APIKey{
ID: 3,
GroupID: &groupID,
Group: &Group{
ID: groupID,
Platform: PlatformAnthropic,
RateMultiplier: 1,
SimulateClaudeMaxEnabled: true,
},
},
User: &User{ID: 4},
Account: &Account{
ID: 5,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
},
}
err := svc.RecordUsage(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, repo.last)
log := repo.last
require.Equal(t, 20, log.InputTokens)
require.Equal(t, 120, log.CacheCreation5mTokens)
require.Equal(t, 0, log.CacheCreation1hTokens)
require.Equal(t, 120, log.CacheCreationTokens)
require.False(t, log.CacheTTLOverridden, "existing cache_creation with SimulateClaudeMax enabled should skip account ttl override")
}

View File

@@ -0,0 +1,170 @@
package service
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation checks
// that, for a simulation-enabled group, the usage returned by
// handleNonStreamingResponse matches the usage written to the client response,
// and that the simulation moved tokens from input into 1h cache creation.
func TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &GatewayService{
cfg: &config.Config{},
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 11,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
}
group := &Group{
ID: 99,
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: true,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "new user question",
},
},
},
},
}
// Upstream reports input tokens only, with no cache creation.
upstreamBody := []byte(`{"id":"msg_1","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: ioNopCloserBytes(upstreamBody),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
c.Set("api_key", &APIKey{Group: group})
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
require.NoError(t, err)
require.NotNil(t, usage)
var rendered struct {
Usage ClaudeUsage `json:"usage"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rendered))
// The 5m/1h buckets live in the nested cache_creation object, so read them by path.
rendered.Usage.CacheCreation5mTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_5m_input_tokens").Int())
rendered.Usage.CacheCreation1hTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_1h_input_tokens").Int())
require.Equal(t, rendered.Usage.InputTokens, usage.InputTokens)
require.Equal(t, rendered.Usage.OutputTokens, usage.OutputTokens)
require.Equal(t, rendered.Usage.CacheCreationInputTokens, usage.CacheCreationInputTokens)
require.Equal(t, rendered.Usage.CacheCreation5mTokens, usage.CacheCreation5mTokens)
require.Equal(t, rendered.Usage.CacheCreation1hTokens, usage.CacheCreation1hTokens)
require.Equal(t, rendered.Usage.CacheReadInputTokens, usage.CacheReadInputTokens)
require.Greater(t, usage.CacheCreation1hTokens, 0)
require.Equal(t, 0, usage.CacheCreation5mTokens)
require.Less(t, usage.InputTokens, 120)
}
// TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept checks
// that, for a group with the simulation disabled, handleNonStreamingResponse
// returns the upstream usage unchanged even though the request carries a cache
// breakpoint.
func TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &GatewayService{
cfg: &config.Config{},
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 12,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
}
group := &Group{
ID: 100,
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: false,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "new user question",
},
},
},
},
}
upstreamBody := []byte(`{"id":"msg_2","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: ioNopCloserBytes(upstreamBody),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
c.Set("api_key", &APIKey{Group: group})
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, 120, usage.InputTokens)
require.Equal(t, 0, usage.CacheCreationInputTokens)
require.Equal(t, 0, usage.CacheCreation5mTokens)
require.Equal(t, 0, usage.CacheCreation1hTokens)
}
// ioNopCloserBytes wraps b in a reader whose Close does nothing, for use as a
// fake http.Response body in tests.
func ioNopCloserBytes(b []byte) *readCloserFromBytes {
return &readCloserFromBytes{Reader: bytes.NewReader(b)}
}
// readCloserFromBytes is a *bytes.Reader with a no-op Close method.
type readCloserFromBytes struct {
*bytes.Reader
}
// Close does nothing and always returns nil.
func (r *readCloserFromBytes) Close() error {
return nil
}

View File

@@ -56,6 +56,12 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// Token-estimation constants used by the Claude Max usage projection.
const (
// Added once per message, and once for a string-form system prompt.
claudeMaxMessageOverheadTokens = 3
// Base cost added for every content block.
claudeMaxBlockOverheadTokens = 1
// Fallback estimate for content of an unrecognized shape or with no measurable payload.
claudeMaxUnknownContentTokens = 4
)
// ForceCacheBillingContextKey is the context key that forces cache billing.
// It is used when a sticky session switches, so that input_tokens are billed as cache_read_input_tokens.
type forceCacheBillingKeyType struct{}
@@ -1228,6 +1234,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
continue
}
// 配额检查
if !s.isAccountSchedulableForQuota(account) {
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
filteredWindowCost++
@@ -1260,6 +1270,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
@@ -1311,7 +1322,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, acc := range routingCandidates {
routingLoads = append(routingLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
@@ -1416,6 +1427,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
s.isAccountSchedulableForQuota(account) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) &&
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
@@ -1480,6 +1492,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
// 配额检查
if !s.isAccountSchedulableForQuota(acc) {
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
@@ -1499,7 +1515,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
@@ -2113,6 +2129,15 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
}
// isAccountSchedulableForQuota reports whether account may be scheduled with
// respect to its quota. Accounts whose Type is not AccountTypeAPIKey are
// always schedulable; API-key accounts are schedulable unless
// account.IsQuotaExceeded() reports true.
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
	return account.Type != AccountTypeAPIKey || !account.IsQuotaExceeded()
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度false 表示不可调度
@@ -2590,7 +2615,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2644,6 +2669,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
@@ -2700,7 +2728,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
return account, nil
}
}
@@ -2743,6 +2771,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
@@ -2818,7 +2849,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
@@ -2874,6 +2905,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
@@ -2930,7 +2964,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
@@ -2975,6 +3009,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
@@ -4317,6 +4354,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 处理正常响应
ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed)
// 触发上游接受回调(提前释放串行锁,不等流完成)
if parsed.OnUpstreamAccepted != nil {
@@ -5773,6 +5811,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
skipAccountTTLOverride := false
pendingEventLines := make([]string, 0, 4)
@@ -5833,17 +5872,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
skipAccountTTLOverride = true
}
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
skipAccountTTLOverride = true
}
}
}
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride {
overrideTarget := account.GetCacheTTLOverrideTarget()
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
@@ -6253,8 +6300,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
body = rewriteClaudeUsageJSONBytes(body, response.Usage)
}
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated {
overrideTarget := account.GetCacheTTLOverrideTarget()
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
@@ -6363,6 +6415,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ParsedRequest *ParsedRequest
APIKey *APIKey
User *User
Account *Account
@@ -6379,6 +6432,89 @@ type APIKeyQuotaUpdater interface {
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
}
// postUsageBillingParams carries the inputs postUsageBilling needs to charge
// for one recorded usage entry.
type postUsageBillingParams struct {
// Cost supplies TotalCost and ActualCost for the charge.
Cost *CostBreakdown
// User is the billed user; its ID is used for balance and cache updates.
User *User
// APIKey is the key the request was made with (quota, rate limits, group).
APIKey *APIKey
// Account is the upstream account that served the request.
Account *Account
// Subscription is dereferenced only when IsSubscriptionBill is true.
Subscription *UserSubscription
// IsSubscriptionBill selects subscription usage (true) or balance deduction (false).
IsSubscriptionBill bool
// AccountRateMultiplier scales TotalCost when updating the account's quota usage.
AccountRateMultiplier float64
// APIKeyService may be nil, in which case API key quota and rate-limit
// updates are skipped.
APIKeyService APIKeyQuotaUpdater
}
// postUsageBilling applies every charge that follows a recorded usage entry,
// in this order:
//  1. subscription usage (TotalCost) or user balance deduction (ActualCost)
//  2. API key quota usage
//  3. API key rate-limit usage
//  4. account quota usage (TotalCost x AccountRateMultiplier), only for
//     API-key accounts that have a quota limit
//  5. scheduling of the account's last-used update
//
// Errors returned by the individual repository/service calls are logged and
// do not stop the remaining steps.
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
	charge := p.Cost
	totalCost, actualCost := charge.TotalCost, charge.ActualCost

	// 1. Subscription usage or balance deduction, depending on billing mode.
	switch {
	case p.IsSubscriptionBill && totalCost > 0:
		if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, totalCost); err != nil {
			slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
		}
		// The cache update is queued even when the repository call failed.
		deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, totalCost)
	case !p.IsSubscriptionBill && actualCost > 0:
		if err := deps.userRepo.DeductBalance(ctx, p.User.ID, actualCost); err != nil {
			slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
		}
		deps.billingCacheService.QueueDeductBalance(p.User.ID, actualCost)
	}

	if actualCost > 0 {
		// 2. API key quota.
		if p.APIKey.Quota > 0 && p.APIKeyService != nil {
			if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, actualCost); err != nil {
				slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
			}
		}

		// 3. API key rate-limit usage.
		if p.APIKey.HasRateLimits() && p.APIKeyService != nil {
			if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, actualCost); err != nil {
				slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
			}
			deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, actualCost)
		}
	}

	// 4. Account quota usage, measured as TotalCost x account rate multiplier.
	if totalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
		scaled := totalCost * p.AccountRateMultiplier
		if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, scaled); err != nil {
			slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", scaled, "error", err)
		}
	}

	// 5. Record that the account was just used.
	deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
}
// billingDeps holds the services the billing logic depends on; each gateway
// service supplies its own instance.
type billingDeps struct {
// accountRepo is used to increment an account's quota usage.
accountRepo AccountRepository
// userRepo is used to deduct user balance.
userRepo UserRepository
// userSubRepo is used to increment subscription usage.
userSubRepo UserSubscriptionRepository
// billingCacheService receives the queued cache updates.
billingCacheService *BillingCacheService
// deferredService schedules the account last-used update.
deferredService *DeferredService
}
// billingDeps returns a freshly allocated billingDeps populated from the
// matching fields of s.
func (s *GatewayService) billingDeps() *billingDeps {
	deps := new(billingDeps)
	deps.accountRepo = s.accountRepo
	deps.userRepo = s.userRepo
	deps.userSubRepo = s.userSubRepo
	deps.billingCacheService = s.billingCacheService
	deps.deferredService = s.deferredService
	return deps
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
@@ -6396,9 +6532,19 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
result.Usage.InputTokens = 0
}
// Claude Max cache billing policy (group-level):
// - GatewayService 路径: Forward 已改写 usage含 cache tokens→ apply 见到 cache tokens 跳过 → simulatedClaudeMax=true通过第二条件
// - Antigravity 路径: Forward 中 hook 改写了客户端 SSE但 ForwardResult.Usage 是原始值 → apply 实际执行模拟 → simulatedClaudeMax=true
var apiKeyGroup *Group
if apiKey != nil {
apiKeyGroup = apiKey.Group
}
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, input.ParsedRequest, apiKeyGroup, result.Model, account.ID)
simulatedClaudeMax := claudeMaxOutcome.Simulated ||
(shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, input.ParsedRequest) && hasCacheCreationTokens(result.Usage))
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
@@ -6542,45 +6688,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// 更新 API Key 配额(如果设置了配额限制)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
@@ -6740,44 +6862,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
// API Key 独立配额扣费
if input.APIKeyService != nil && apiKey.Quota > 0 {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
}
}
}
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}

View File

@@ -176,6 +176,14 @@ func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64,
return 0, nil
}
// IncrementQuotaUsed is a no-op stub: it ignores its arguments and always
// returns nil.
func (m *mockAccountRepoForGemini) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
// ResetQuotaUsed is a no-op stub: it ignores its arguments and always
// returns nil.
func (m *mockAccountRepoForGemini) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)

View File

@@ -50,6 +50,9 @@ type Group struct {
// MCP XML 协议注入开关(仅 antigravity 平台使用)
MCPXMLInject bool
// Claude usage 模拟开关:将无写缓存 usage 模拟为 claude-max 风格
SimulateClaudeMaxEnabled bool
// 支持的模型系列(仅 antigravity 平台使用)
// 可选值: claude, gemini_text, gemini_image
SupportedModelScopes []string

View File

@@ -590,7 +590,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
filtered = append(filtered, account)
loadReq = append(loadReq, AccountWithConcurrency{
ID: account.ID,
MaxConcurrency: account.Concurrency,
MaxConcurrency: account.EffectiveLoadFactor(),
})
}
if len(filtered) == 0 {

View File

@@ -319,6 +319,16 @@ func NewOpenAIGatewayService(
return svc
}
// billingDeps returns a freshly allocated billingDeps populated from the
// matching fields of s.
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
	deps := new(billingDeps)
	deps.accountRepo = s.accountRepo
	deps.userRepo = s.userRepo
	deps.userSubRepo = s.userSubRepo
	deps.billingCacheService = s.billingCacheService
	deps.deferredService = s.deferredService
	return deps
}
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
// 应在应用优雅关闭时调用。
func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
@@ -1242,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
@@ -3474,37 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
shouldBill := inserted || err != nil
// Deduct based on billing type
if isSubscriptionBilling {
if shouldBill && cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
if shouldBill && cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}
// Update API key quota if applicable (only for balance mode with quota set)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}

View File

@@ -864,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool {
strings.Contains(message, "unexpected eof") ||
strings.Contains(message, "use of closed network connection") ||
strings.Contains(message, "connection reset by peer") ||
strings.Contains(message, "broken pipe")
strings.Contains(message, "broken pipe") ||
strings.Contains(message, "an established connection was aborted")
}
func classifyOpenAIWSReadFallbackReason(err error) string {

View File

@@ -64,8 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
if acc.ID <= 0 {
continue
}
if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev {
unique[acc.ID] = acc.Concurrency
lf := acc.EffectiveLoadFactor()
if prev, ok := unique[acc.ID]; !ok || lf > prev {
unique[acc.ID] = lf
}
}

View File

@@ -389,13 +389,9 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
if acc.ID <= 0 {
continue
}
maxConc := acc.Concurrency
if maxConc < 0 {
maxConc = 0
}
batch = append(batch, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: maxConc,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
if len(batch) == 0 {

View File

@@ -34,9 +34,10 @@ func TestCalculateProgress_BasicFields(t *testing.T) {
assert.Equal(t, int64(100), progress.ID)
assert.Equal(t, "Premium", progress.GroupName)
assert.Equal(t, sub.ExpiresAt, progress.ExpiresAt)
assert.Equal(t, 29, progress.ExpiresInDays) // 约 30 天
assert.Nil(t, progress.Daily, "无日限额时 Daily 应为 nil")
assert.Nil(t, progress.Weekly, "无周限额时 Weekly 应为 nil")
assert.GreaterOrEqual(t, progress.ExpiresInDays, 29)
assert.LessOrEqual(t, progress.ExpiresInDays, 30)
assert.Nil(t, progress.Daily)
assert.Nil(t, progress.Weekly)
assert.Nil(t, progress.Monthly, "无月限额时 Monthly 应为 nil")
}

View File

@@ -0,0 +1,42 @@
-- Add claude-sonnet-4-6 to model_mapping for all Antigravity accounts
--
-- Background:
-- Antigravity now supports claude-sonnet-4-6
--
-- Strategy:
-- Directly overwrite the entire model_mapping with updated mappings
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
--
-- Scope:
-- Only non-deleted antigravity accounts whose credentials already contain a
-- model_mapping key are updated; accounts without that key are left unchanged.
-- Because the whole object is replaced, any per-account customization of
-- model_mapping is lost.
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{model_mapping}',
'{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-6": "claude-opus-4-6-thinking",
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}'::jsonb
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping' IS NOT NULL;

View File

@@ -0,0 +1,45 @@
-- Add gemini-3.1-pro-high, gemini-3.1-pro-low, gemini-3.1-pro-preview to model_mapping
--
-- Background:
-- Antigravity now supports gemini-3.1-pro-high and gemini-3.1-pro-low
-- (gemini-3.1-pro-preview is mapped to gemini-3.1-pro-high below)
--
-- Strategy:
-- Directly overwrite the entire model_mapping with updated mappings
-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
--
-- Scope:
-- Only non-deleted antigravity accounts whose credentials already contain a
-- model_mapping key are updated; accounts without that key are left unchanged.
-- Because the whole object is replaced, any per-account customization of
-- model_mapping is lost.
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{model_mapping}',
'{
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-6": "claude-opus-4-6-thinking",
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
"claude-sonnet-4-6": "claude-sonnet-4-6",
"claude-sonnet-4-5": "claude-sonnet-4-5",
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
"gemini-3-flash": "gemini-3-flash",
"gemini-3-pro-high": "gemini-3-pro-high",
"gemini-3-pro-low": "gemini-3-pro-low",
"gemini-3-pro-image": "gemini-3-pro-image",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview"
}'::jsonb
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping' IS NOT NULL;

View File

@@ -0,0 +1,3 @@
-- Add the simulate_claude_max_enabled flag to groups.
-- The column is NOT NULL with DEFAULT FALSE, so existing rows get FALSE.
-- IF NOT EXISTS makes the statement a no-op when the column is already present.
ALTER TABLE groups
ADD COLUMN IF NOT EXISTS simulate_claude_max_enabled BOOLEAN NOT NULL DEFAULT FALSE;

View File

@@ -0,0 +1 @@
-- Add a nullable load_factor column to accounts (no default, so existing rows
-- hold NULL). IF NOT EXISTS makes the statement a no-op when the column is
-- already present.
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS load_factor INTEGER;