mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-17 21:34:45 +08:00
Add invalid-request fallback routing
This commit is contained in:
@@ -56,6 +56,8 @@ type Group struct {
|
|||||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
// 非 Claude Code 请求降级使用的分组 ID
|
// 非 Claude Code 请求降级使用的分组 ID
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
// 无效请求兜底使用的分组 ID
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||||
// 模型路由配置:模型模式 -> 优先账号ID列表
|
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
// 是否启用模型路由配置
|
// 是否启用模型路由配置
|
||||||
@@ -172,7 +174,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
|
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@@ -322,6 +324,13 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
_m.FallbackGroupID = new(int64)
|
_m.FallbackGroupID = new(int64)
|
||||||
*_m.FallbackGroupID = value.Int64
|
*_m.FallbackGroupID = value.Int64
|
||||||
}
|
}
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field fallback_group_id_on_invalid_request", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.FallbackGroupIDOnInvalidRequest = new(int64)
|
||||||
|
*_m.FallbackGroupIDOnInvalidRequest = value.Int64
|
||||||
|
}
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
if value, ok := values[i].(*[]byte); !ok {
|
if value, ok := values[i].(*[]byte); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
|
return fmt.Errorf("unexpected type %T for field model_routing", values[i])
|
||||||
@@ -487,6 +496,11 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
if v := _m.FallbackGroupIDOnInvalidRequest; v != nil {
|
||||||
|
builder.WriteString("fallback_group_id_on_invalid_request=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
builder.WriteString("model_routing=")
|
builder.WriteString("model_routing=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
|
builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ const (
|
|||||||
FieldClaudeCodeOnly = "claude_code_only"
|
FieldClaudeCodeOnly = "claude_code_only"
|
||||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||||
FieldFallbackGroupID = "fallback_group_id"
|
FieldFallbackGroupID = "fallback_group_id"
|
||||||
|
// FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database.
|
||||||
|
FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request"
|
||||||
// FieldModelRouting holds the string denoting the model_routing field in the database.
|
// FieldModelRouting holds the string denoting the model_routing field in the database.
|
||||||
FieldModelRouting = "model_routing"
|
FieldModelRouting = "model_routing"
|
||||||
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
|
// FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
|
||||||
@@ -151,6 +153,7 @@ var Columns = []string{
|
|||||||
FieldImagePrice4k,
|
FieldImagePrice4k,
|
||||||
FieldClaudeCodeOnly,
|
FieldClaudeCodeOnly,
|
||||||
FieldFallbackGroupID,
|
FieldFallbackGroupID,
|
||||||
|
FieldFallbackGroupIDOnInvalidRequest,
|
||||||
FieldModelRouting,
|
FieldModelRouting,
|
||||||
FieldModelRoutingEnabled,
|
FieldModelRoutingEnabled,
|
||||||
}
|
}
|
||||||
@@ -317,6 +320,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByFallbackGroupIDOnInvalidRequest orders the results by the fallback_group_id_on_invalid_request field.
|
||||||
|
func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
|
// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
|
||||||
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
|
func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
|
return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
|
||||||
|
|||||||
@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequest applies equality check predicate on the "fallback_group_id_on_invalid_request" field. It's identical to FallbackGroupIDOnInvalidRequestEQ.
|
||||||
|
func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
|
// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
|
||||||
func ModelRoutingEnabled(v bool) predicate.Group {
|
func ModelRoutingEnabled(v bool) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
|
return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
|
||||||
@@ -1070,6 +1075,56 @@ func FallbackGroupIDNotNil() predicate.Group {
|
|||||||
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestEQ applies the EQ predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestEQ(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestNEQ applies the NEQ predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestNEQ(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestIn applies the In predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestIn(vs ...int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIn(FieldFallbackGroupIDOnInvalidRequest, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestNotIn applies the NotIn predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestNotIn(vs ...int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupIDOnInvalidRequest, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestGT applies the GT predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestGT(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGT(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestGTE applies the GTE predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestGTE(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGTE(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestLT applies the LT predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestLT(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLT(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestLTE applies the LTE predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestLTE(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLTE(FieldFallbackGroupIDOnInvalidRequest, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestIsNil applies the IsNil predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestIsNil() predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupIDOnInvalidRequest))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestNotNil applies the NotNil predicate on the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest))
|
||||||
|
}
|
||||||
|
|
||||||
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
|
// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
|
||||||
func ModelRoutingIsNil() predicate.Group {
|
func ModelRoutingIsNil() predicate.Group {
|
||||||
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
|
return predicate.Group(sql.FieldIsNull(FieldModelRouting))
|
||||||
|
|||||||
@@ -286,6 +286,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_c *GroupCreate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupCreate {
|
||||||
|
_c.mutation.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetFallbackGroupIDOnInvalidRequest(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
|
func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
|
||||||
_c.mutation.SetModelRouting(v)
|
_c.mutation.SetModelRouting(v)
|
||||||
@@ -640,6 +654,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||||
_node.FallbackGroupID = &value
|
_node.FallbackGroupID = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.FallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
_node.FallbackGroupIDOnInvalidRequest = &value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.ModelRouting(); ok {
|
if value, ok := _c.mutation.ModelRouting(); ok {
|
||||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
_node.ModelRouting = value
|
_node.ModelRouting = value
|
||||||
@@ -1128,6 +1146,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsert) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert {
|
||||||
|
u.Set(group.FieldFallbackGroupIDOnInvalidRequest, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsert) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert {
|
||||||
|
u.Add(group.FieldFallbackGroupIDOnInvalidRequest, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert {
|
||||||
|
u.SetNull(group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
|
func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
|
||||||
u.Set(group.FieldModelRouting, v)
|
u.Set(group.FieldModelRouting, v)
|
||||||
@@ -1581,6 +1623,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateFallbackGroupIDOnInvalidRequest()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
|
func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
|
||||||
return u.Update(func(s *GroupUpsert) {
|
return u.Update(func(s *GroupUpsert) {
|
||||||
@@ -2205,6 +2275,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertBulk) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertBulk) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateFallbackGroupIDOnInvalidRequest()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
|
func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
|
||||||
return u.Update(func(s *GroupUpsert) {
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
|||||||
@@ -395,6 +395,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate {
|
||||||
|
_u.mutation.ResetFallbackGroupIDOnInvalidRequest()
|
||||||
|
_u.mutation.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetFallbackGroupIDOnInvalidRequest(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdate) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate {
|
||||||
|
_u.mutation.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate {
|
||||||
|
_u.mutation.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
|
func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
|
||||||
_u.mutation.SetModelRouting(v)
|
_u.mutation.SetModelRouting(v)
|
||||||
@@ -829,6 +856,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.FallbackGroupIDCleared() {
|
if _u.mutation.FallbackGroupIDCleared() {
|
||||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() {
|
||||||
|
_spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.ModelRouting(); ok {
|
if value, ok := _u.mutation.ModelRouting(); ok {
|
||||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
}
|
}
|
||||||
@@ -1513,6 +1549,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdateOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.ResetFallbackGroupIDOnInvalidRequest()
|
||||||
|
_u.mutation.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetFallbackGroupIDOnInvalidRequest(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdateOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne {
|
||||||
|
_u.mutation.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
|
||||||
_u.mutation.SetModelRouting(v)
|
_u.mutation.SetModelRouting(v)
|
||||||
@@ -1977,6 +2040,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if _u.mutation.FallbackGroupIDCleared() {
|
if _u.mutation.FallbackGroupIDCleared() {
|
||||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok {
|
||||||
|
_spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() {
|
||||||
|
_spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.ModelRouting(); ok {
|
if value, ok := _u.mutation.ModelRouting(); ok {
|
||||||
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
_spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -226,6 +226,7 @@ var (
|
|||||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||||
|
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
|
||||||
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
{Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||||
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
|
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3833,61 +3833,63 @@ func (m *AccountGroupMutation) ResetEdge(name string) error {
|
|||||||
// GroupMutation represents an operation that mutates the Group nodes in the graph.
|
// GroupMutation represents an operation that mutates the Group nodes in the graph.
|
||||||
type GroupMutation struct {
|
type GroupMutation struct {
|
||||||
config
|
config
|
||||||
op Op
|
op Op
|
||||||
typ string
|
typ string
|
||||||
id *int64
|
id *int64
|
||||||
created_at *time.Time
|
created_at *time.Time
|
||||||
updated_at *time.Time
|
updated_at *time.Time
|
||||||
deleted_at *time.Time
|
deleted_at *time.Time
|
||||||
name *string
|
name *string
|
||||||
description *string
|
description *string
|
||||||
rate_multiplier *float64
|
rate_multiplier *float64
|
||||||
addrate_multiplier *float64
|
addrate_multiplier *float64
|
||||||
is_exclusive *bool
|
is_exclusive *bool
|
||||||
status *string
|
status *string
|
||||||
platform *string
|
platform *string
|
||||||
subscription_type *string
|
subscription_type *string
|
||||||
daily_limit_usd *float64
|
daily_limit_usd *float64
|
||||||
adddaily_limit_usd *float64
|
adddaily_limit_usd *float64
|
||||||
weekly_limit_usd *float64
|
weekly_limit_usd *float64
|
||||||
addweekly_limit_usd *float64
|
addweekly_limit_usd *float64
|
||||||
monthly_limit_usd *float64
|
monthly_limit_usd *float64
|
||||||
addmonthly_limit_usd *float64
|
addmonthly_limit_usd *float64
|
||||||
default_validity_days *int
|
default_validity_days *int
|
||||||
adddefault_validity_days *int
|
adddefault_validity_days *int
|
||||||
image_price_1k *float64
|
image_price_1k *float64
|
||||||
addimage_price_1k *float64
|
addimage_price_1k *float64
|
||||||
image_price_2k *float64
|
image_price_2k *float64
|
||||||
addimage_price_2k *float64
|
addimage_price_2k *float64
|
||||||
image_price_4k *float64
|
image_price_4k *float64
|
||||||
addimage_price_4k *float64
|
addimage_price_4k *float64
|
||||||
claude_code_only *bool
|
claude_code_only *bool
|
||||||
fallback_group_id *int64
|
fallback_group_id *int64
|
||||||
addfallback_group_id *int64
|
addfallback_group_id *int64
|
||||||
model_routing *map[string][]int64
|
fallback_group_id_on_invalid_request *int64
|
||||||
model_routing_enabled *bool
|
addfallback_group_id_on_invalid_request *int64
|
||||||
clearedFields map[string]struct{}
|
model_routing *map[string][]int64
|
||||||
api_keys map[int64]struct{}
|
model_routing_enabled *bool
|
||||||
removedapi_keys map[int64]struct{}
|
clearedFields map[string]struct{}
|
||||||
clearedapi_keys bool
|
api_keys map[int64]struct{}
|
||||||
redeem_codes map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
removedredeem_codes map[int64]struct{}
|
clearedapi_keys bool
|
||||||
clearedredeem_codes bool
|
redeem_codes map[int64]struct{}
|
||||||
subscriptions map[int64]struct{}
|
removedredeem_codes map[int64]struct{}
|
||||||
removedsubscriptions map[int64]struct{}
|
clearedredeem_codes bool
|
||||||
clearedsubscriptions bool
|
subscriptions map[int64]struct{}
|
||||||
usage_logs map[int64]struct{}
|
removedsubscriptions map[int64]struct{}
|
||||||
removedusage_logs map[int64]struct{}
|
clearedsubscriptions bool
|
||||||
clearedusage_logs bool
|
usage_logs map[int64]struct{}
|
||||||
accounts map[int64]struct{}
|
removedusage_logs map[int64]struct{}
|
||||||
removedaccounts map[int64]struct{}
|
clearedusage_logs bool
|
||||||
clearedaccounts bool
|
accounts map[int64]struct{}
|
||||||
allowed_users map[int64]struct{}
|
removedaccounts map[int64]struct{}
|
||||||
removedallowed_users map[int64]struct{}
|
clearedaccounts bool
|
||||||
clearedallowed_users bool
|
allowed_users map[int64]struct{}
|
||||||
done bool
|
removedallowed_users map[int64]struct{}
|
||||||
oldValue func(context.Context) (*Group, error)
|
clearedallowed_users bool
|
||||||
predicates []predicate.Group
|
done bool
|
||||||
|
oldValue func(context.Context) (*Group, error)
|
||||||
|
predicates []predicate.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ent.Mutation = (*GroupMutation)(nil)
|
var _ ent.Mutation = (*GroupMutation)(nil)
|
||||||
@@ -4976,6 +4978,76 @@ func (m *GroupMutation) ResetFallbackGroupID() {
|
|||||||
delete(m.clearedFields, group.FieldFallbackGroupID)
|
delete(m.clearedFields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) {
|
||||||
|
m.fallback_group_id_on_invalid_request = &i
|
||||||
|
m.addfallback_group_id_on_invalid_request = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation.
|
||||||
|
func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
|
||||||
|
v := m.fallback_group_id_on_invalid_request
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.FallbackGroupIDOnInvalidRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) {
|
||||||
|
if m.addfallback_group_id_on_invalid_request != nil {
|
||||||
|
*m.addfallback_group_id_on_invalid_request += i
|
||||||
|
} else {
|
||||||
|
m.addfallback_group_id_on_invalid_request = &i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation.
|
||||||
|
func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
|
||||||
|
v := m.addfallback_group_id_on_invalid_request
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() {
|
||||||
|
m.fallback_group_id_on_invalid_request = nil
|
||||||
|
m.addfallback_group_id_on_invalid_request = nil
|
||||||
|
m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation.
|
||||||
|
func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool {
|
||||||
|
_, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field.
|
||||||
|
func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() {
|
||||||
|
m.fallback_group_id_on_invalid_request = nil
|
||||||
|
m.addfallback_group_id_on_invalid_request = nil
|
||||||
|
delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
// SetModelRouting sets the "model_routing" field.
|
// SetModelRouting sets the "model_routing" field.
|
||||||
func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
|
func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
|
||||||
m.model_routing = &value
|
m.model_routing = &value
|
||||||
@@ -5419,7 +5491,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 21)
|
fields := make([]string, 0, 22)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -5477,6 +5549,9 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.fallback_group_id != nil {
|
if m.fallback_group_id != nil {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
if m.fallback_group_id_on_invalid_request != nil {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
if m.model_routing != nil {
|
if m.model_routing != nil {
|
||||||
fields = append(fields, group.FieldModelRouting)
|
fields = append(fields, group.FieldModelRouting)
|
||||||
}
|
}
|
||||||
@@ -5529,6 +5604,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.ClaudeCodeOnly()
|
return m.ClaudeCodeOnly()
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.FallbackGroupID()
|
return m.FallbackGroupID()
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
return m.FallbackGroupIDOnInvalidRequest()
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
return m.ModelRouting()
|
return m.ModelRouting()
|
||||||
case group.FieldModelRoutingEnabled:
|
case group.FieldModelRoutingEnabled:
|
||||||
@@ -5580,6 +5657,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldClaudeCodeOnly(ctx)
|
return m.OldClaudeCodeOnly(ctx)
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.OldFallbackGroupID(ctx)
|
return m.OldFallbackGroupID(ctx)
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
return m.OldFallbackGroupIDOnInvalidRequest(ctx)
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
return m.OldModelRouting(ctx)
|
return m.OldModelRouting(ctx)
|
||||||
case group.FieldModelRoutingEnabled:
|
case group.FieldModelRoutingEnabled:
|
||||||
@@ -5726,6 +5805,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetFallbackGroupID(v)
|
m.SetFallbackGroupID(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
v, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return nil
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
v, ok := value.(map[string][]int64)
|
v, ok := value.(map[string][]int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -5775,6 +5861,9 @@ func (m *GroupMutation) AddedFields() []string {
|
|||||||
if m.addfallback_group_id != nil {
|
if m.addfallback_group_id != nil {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
if m.addfallback_group_id_on_invalid_request != nil {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5801,6 +5890,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedImagePrice4k()
|
return m.AddedImagePrice4k()
|
||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
return m.AddedFallbackGroupID()
|
return m.AddedFallbackGroupID()
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
return m.AddedFallbackGroupIDOnInvalidRequest()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -5873,6 +5964,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddFallbackGroupID(v)
|
m.AddFallbackGroupID(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
v, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddFallbackGroupIDOnInvalidRequest(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group numeric field %s", name)
|
return fmt.Errorf("unknown Group numeric field %s", name)
|
||||||
}
|
}
|
||||||
@@ -5908,6 +6006,9 @@ func (m *GroupMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(group.FieldFallbackGroupID) {
|
if m.FieldCleared(group.FieldFallbackGroupID) {
|
||||||
fields = append(fields, group.FieldFallbackGroupID)
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
if m.FieldCleared(group.FieldModelRouting) {
|
if m.FieldCleared(group.FieldModelRouting) {
|
||||||
fields = append(fields, group.FieldModelRouting)
|
fields = append(fields, group.FieldModelRouting)
|
||||||
}
|
}
|
||||||
@@ -5952,6 +6053,9 @@ func (m *GroupMutation) ClearField(name string) error {
|
|||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
m.ClearFallbackGroupID()
|
m.ClearFallbackGroupID()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
m.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
return nil
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
m.ClearModelRouting()
|
m.ClearModelRouting()
|
||||||
return nil
|
return nil
|
||||||
@@ -6020,6 +6124,9 @@ func (m *GroupMutation) ResetField(name string) error {
|
|||||||
case group.FieldFallbackGroupID:
|
case group.FieldFallbackGroupID:
|
||||||
m.ResetFallbackGroupID()
|
m.ResetFallbackGroupID()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupIDOnInvalidRequest:
|
||||||
|
m.ResetFallbackGroupIDOnInvalidRequest()
|
||||||
|
return nil
|
||||||
case group.FieldModelRouting:
|
case group.FieldModelRouting:
|
||||||
m.ResetModelRouting()
|
m.ResetModelRouting()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -281,7 +281,7 @@ func init() {
|
|||||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||||
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
||||||
groupDescModelRoutingEnabled := groupFields[17].Descriptor()
|
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
|
||||||
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
||||||
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
||||||
promocodeFields := schema.PromoCode{}.Fields()
|
promocodeFields := schema.PromoCode{}.Fields()
|
||||||
|
|||||||
@@ -95,6 +95,10 @@ func (Group) Fields() []ent.Field {
|
|||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||||
|
field.Int64("fallback_group_id_on_invalid_request").
|
||||||
|
Optional().
|
||||||
|
Nillable().
|
||||||
|
Comment("无效请求兜底使用的分组 ID"),
|
||||||
|
|
||||||
// 模型路由配置 (added by migration 040)
|
// 模型路由配置 (added by migration 040)
|
||||||
field.JSON("model_routing", map[string][]int64{}).
|
field.JSON("model_routing", map[string][]int64{}).
|
||||||
|
|||||||
@@ -35,11 +35,12 @@ type CreateGroupRequest struct {
|
|||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
@@ -58,11 +59,12 @@ type UpdateGroupRequest struct {
|
|||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||||
@@ -155,22 +157,23 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Platform: req.Platform,
|
Platform: req.Platform,
|
||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
FallbackGroupID: req.FallbackGroupID,
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
ModelRouting: req.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRouting: req.ModelRouting,
|
||||||
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@@ -196,23 +199,24 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Platform: req.Platform,
|
Platform: req.Platform,
|
||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
FallbackGroupID: req.FallbackGroupID,
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
ModelRouting: req.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRouting: req.ModelRouting,
|
||||||
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@@ -73,27 +73,28 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &Group{
|
return &Group{
|
||||||
ID: g.ID,
|
ID: g.ID,
|
||||||
Name: g.Name,
|
Name: g.Name,
|
||||||
Description: g.Description,
|
Description: g.Description,
|
||||||
Platform: g.Platform,
|
Platform: g.Platform,
|
||||||
RateMultiplier: g.RateMultiplier,
|
RateMultiplier: g.RateMultiplier,
|
||||||
IsExclusive: g.IsExclusive,
|
IsExclusive: g.IsExclusive,
|
||||||
Status: g.Status,
|
Status: g.Status,
|
||||||
SubscriptionType: g.SubscriptionType,
|
SubscriptionType: g.SubscriptionType,
|
||||||
DailyLimitUSD: g.DailyLimitUSD,
|
DailyLimitUSD: g.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||||
ImagePrice1K: g.ImagePrice1K,
|
ImagePrice1K: g.ImagePrice1K,
|
||||||
ImagePrice2K: g.ImagePrice2K,
|
ImagePrice2K: g.ImagePrice2K,
|
||||||
ImagePrice4K: g.ImagePrice4K,
|
ImagePrice4K: g.ImagePrice4K,
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
ModelRouting: g.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRouting: g.ModelRouting,
|
||||||
CreatedAt: g.CreatedAt,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
UpdatedAt: g.UpdatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
AccountCount: g.AccountCount,
|
UpdatedAt: g.UpdatedAt,
|
||||||
|
AccountCount: g.AccountCount,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ type Group struct {
|
|||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
|
// 无效请求兜底分组
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
|
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
@@ -325,136 +326,186 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
currentAPIKey := apiKey
|
||||||
switchCount := 0
|
currentSubscription := subscription
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
var fallbackGroupID *int64
|
||||||
lastFailoverStatus := 0
|
if apiKey.Group != nil {
|
||||||
|
fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest
|
||||||
|
}
|
||||||
|
fallbackUsed := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// 选择支持该模型的账号
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
switchCount := 0
|
||||||
if err != nil {
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
if len(failedAccountIDs) == 0 {
|
lastFailoverStatus := 0
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
retryWithFallback := false
|
||||||
return
|
|
||||||
}
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
account := selection.Account
|
|
||||||
setOpsSelectedAccount(c, account.ID)
|
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
for {
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
// 选择支持该模型的账号
|
||||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||||
selection.ReleaseFunc()
|
|
||||||
}
|
|
||||||
if reqStream {
|
|
||||||
sendMockWarmupStream(c, reqModel)
|
|
||||||
} else {
|
|
||||||
sendMockWarmupResponse(c, reqModel)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
|
||||||
if !selection.Acquired {
|
|
||||||
if selection.WaitPlan == nil {
|
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
accountWaitCounted := false
|
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment account wait count failed: %v", err)
|
if len(failedAccountIDs) == 0 {
|
||||||
} else if !canWait {
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err == nil && canWait {
|
|
||||||
accountWaitCounted = true
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if accountWaitCounted {
|
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
|
||||||
c,
|
|
||||||
account.ID,
|
|
||||||
selection.WaitPlan.MaxConcurrency,
|
|
||||||
selection.WaitPlan.Timeout,
|
|
||||||
reqStream,
|
|
||||||
&streamStarted,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if accountWaitCounted {
|
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
|
||||||
accountWaitCounted = false
|
|
||||||
}
|
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
|
||||||
var result *service.ForwardResult
|
|
||||||
if account.Platform == service.PlatformAntigravity {
|
|
||||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
|
||||||
} else {
|
|
||||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
|
||||||
}
|
|
||||||
if accountReleaseFunc != nil {
|
|
||||||
accountReleaseFunc()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
var failoverErr *service.UpstreamFailoverError
|
|
||||||
if errors.As(err, &failoverErr) {
|
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
if switchCount >= maxAccountSwitches {
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
return
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
account := selection.Account
|
||||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
if reqStream {
|
||||||
|
sendMockWarmupStream(c, reqModel)
|
||||||
|
} else {
|
||||||
|
sendMockWarmupResponse(c, reqModel)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 获取账号并发槽位
|
||||||
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
|
if !selection.Acquired {
|
||||||
|
if selection.WaitPlan == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
accountWaitCounted := false
|
||||||
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
|
} else if !canWait {
|
||||||
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil && canWait {
|
||||||
|
accountWaitCounted = true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if accountWaitCounted {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
|
c,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
selection.WaitPlan.Timeout,
|
||||||
|
reqStream,
|
||||||
|
&streamStarted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if accountWaitCounted {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
|
}
|
||||||
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
|
|
||||||
|
// 转发请求 - 根据账号平台分流
|
||||||
|
var result *service.ForwardResult
|
||||||
|
if account.Platform == service.PlatformAntigravity {
|
||||||
|
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
||||||
|
} else {
|
||||||
|
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
||||||
|
}
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var promptTooLongErr *service.PromptTooLongError
|
||||||
|
if errors.As(err, &promptTooLongErr) {
|
||||||
|
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
|
||||||
|
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
|
||||||
|
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Resolve fallback group failed: %v", err)
|
||||||
|
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if fallbackGroup.Platform != service.PlatformAnthropic ||
|
||||||
|
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
|
||||||
|
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||||
|
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
|
||||||
|
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
||||||
|
status, code, message := billingErrorDetails(err)
|
||||||
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
currentAPIKey = fallbackAPIKey
|
||||||
|
currentSubscription = nil
|
||||||
|
fallbackUsed = true
|
||||||
|
retryWithFallback = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switchCount++
|
||||||
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 错误响应已在Forward中处理,这里只记录日志
|
||||||
|
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: currentAPIKey,
|
||||||
|
User: currentAPIKey.User,
|
||||||
|
Account: usedAccount,
|
||||||
|
Subscription: currentSubscription,
|
||||||
|
UserAgent: ua,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: %v", err)
|
||||||
|
}
|
||||||
|
}(result, account, userAgent, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
if !retryWithFallback {
|
||||||
userAgent := c.GetHeader("User-Agent")
|
return
|
||||||
clientIP := ip.GetClientIP(c)
|
}
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
APIKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: usedAccount,
|
|
||||||
Subscription: subscription,
|
|
||||||
UserAgent: ua,
|
|
||||||
IPAddress: clientIP,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("Record usage failed: %v", err)
|
|
||||||
}
|
|
||||||
}(result, account, userAgent, clientIP)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -518,6 +569,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey {
|
||||||
|
if apiKey == nil || group == nil {
|
||||||
|
return apiKey
|
||||||
|
}
|
||||||
|
cloned := *apiKey
|
||||||
|
groupID := group.ID
|
||||||
|
cloned.GroupID = &groupID
|
||||||
|
cloned.Group = group
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
// Usage handles getting account balance for CC Switch integration
|
// Usage handles getting account balance for CC Switch integration
|
||||||
// GET /v1/usage
|
// GET /v1/usage
|
||||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||||
|
|||||||
@@ -136,6 +136,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
group.FieldImagePrice4k,
|
group.FieldImagePrice4k,
|
||||||
group.FieldClaudeCodeOnly,
|
group.FieldClaudeCodeOnly,
|
||||||
group.FieldFallbackGroupID,
|
group.FieldFallbackGroupID,
|
||||||
|
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||||
group.FieldModelRoutingEnabled,
|
group.FieldModelRoutingEnabled,
|
||||||
group.FieldModelRouting,
|
group.FieldModelRouting,
|
||||||
)
|
)
|
||||||
@@ -406,28 +407,29 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &service.Group{
|
return &service.Group{
|
||||||
ID: g.ID,
|
ID: g.ID,
|
||||||
Name: g.Name,
|
Name: g.Name,
|
||||||
Description: derefString(g.Description),
|
Description: derefString(g.Description),
|
||||||
Platform: g.Platform,
|
Platform: g.Platform,
|
||||||
RateMultiplier: g.RateMultiplier,
|
RateMultiplier: g.RateMultiplier,
|
||||||
IsExclusive: g.IsExclusive,
|
IsExclusive: g.IsExclusive,
|
||||||
Status: g.Status,
|
Status: g.Status,
|
||||||
Hydrated: true,
|
Hydrated: true,
|
||||||
SubscriptionType: g.SubscriptionType,
|
SubscriptionType: g.SubscriptionType,
|
||||||
DailyLimitUSD: g.DailyLimitUsd,
|
DailyLimitUSD: g.DailyLimitUsd,
|
||||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||||
ImagePrice1K: g.ImagePrice1k,
|
ImagePrice1K: g.ImagePrice1k,
|
||||||
ImagePrice2K: g.ImagePrice2k,
|
ImagePrice2K: g.ImagePrice2k,
|
||||||
ImagePrice4K: g.ImagePrice4k,
|
ImagePrice4K: g.ImagePrice4k,
|
||||||
DefaultValidityDays: g.DefaultValidityDays,
|
DefaultValidityDays: g.DefaultValidityDays,
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
ModelRouting: g.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRouting: g.ModelRouting,
|
||||||
CreatedAt: g.CreatedAt,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
UpdatedAt: g.UpdatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||||
|
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
||||||
|
|
||||||
// 设置模型路由配置
|
// 设置模型路由配置
|
||||||
@@ -116,6 +117,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
} else {
|
} else {
|
||||||
builder = builder.ClearFallbackGroupID()
|
builder = builder.ClearFallbackGroupID()
|
||||||
}
|
}
|
||||||
|
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
|
||||||
|
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
|
||||||
|
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
|
||||||
|
} else {
|
||||||
|
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
|
||||||
|
}
|
||||||
|
|
||||||
// 处理 ModelRouting:nil 时清除,否则设置
|
// 处理 ModelRouting:nil 时清除,否则设置
|
||||||
if groupIn.ModelRouting != nil {
|
if groupIn.ModelRouting != nil {
|
||||||
|
|||||||
@@ -108,6 +108,8 @@ type CreateGroupInput struct {
|
|||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||||
FallbackGroupID *int64 // 降级分组 ID
|
FallbackGroupID *int64 // 降级分组 ID
|
||||||
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled bool // 是否启用模型路由
|
ModelRoutingEnabled bool // 是否启用模型路由
|
||||||
@@ -130,6 +132,8 @@ type UpdateGroupInput struct {
|
|||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||||
FallbackGroupID *int64 // 降级分组 ID
|
FallbackGroupID *int64 // 降级分组 ID
|
||||||
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||||
@@ -572,24 +576,35 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
|
||||||
|
if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
|
||||||
|
fallbackOnInvalidRequest = nil
|
||||||
|
}
|
||||||
|
// 校验无效请求兜底分组
|
||||||
|
if fallbackOnInvalidRequest != nil {
|
||||||
|
if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
group := &Group{
|
group := &Group{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
Description: input.Description,
|
Description: input.Description,
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
RateMultiplier: input.RateMultiplier,
|
RateMultiplier: input.RateMultiplier,
|
||||||
IsExclusive: input.IsExclusive,
|
IsExclusive: input.IsExclusive,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
SubscriptionType: subscriptionType,
|
SubscriptionType: subscriptionType,
|
||||||
DailyLimitUSD: dailyLimit,
|
DailyLimitUSD: dailyLimit,
|
||||||
WeeklyLimitUSD: weeklyLimit,
|
WeeklyLimitUSD: weeklyLimit,
|
||||||
MonthlyLimitUSD: monthlyLimit,
|
MonthlyLimitUSD: monthlyLimit,
|
||||||
ImagePrice1K: imagePrice1K,
|
ImagePrice1K: imagePrice1K,
|
||||||
ImagePrice2K: imagePrice2K,
|
ImagePrice2K: imagePrice2K,
|
||||||
ImagePrice4K: imagePrice4K,
|
ImagePrice4K: imagePrice4K,
|
||||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||||
FallbackGroupID: input.FallbackGroupID,
|
FallbackGroupID: input.FallbackGroupID,
|
||||||
ModelRouting: input.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||||
|
ModelRouting: input.ModelRouting,
|
||||||
}
|
}
|
||||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -651,6 +666,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
|
||||||
|
// currentGroupID: 当前分组 ID(新建时为 0)
|
||||||
|
// platform/subscriptionType: 当前分组的有效平台/订阅类型
|
||||||
|
// fallbackGroupID: 兜底分组 ID
|
||||||
|
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
|
||||||
|
if platform != PlatformAnthropic && platform != PlatformAntigravity {
|
||||||
|
return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
|
||||||
|
}
|
||||||
|
if subscriptionType == SubscriptionTypeSubscription {
|
||||||
|
return fmt.Errorf("subscription groups cannot set invalid request fallback")
|
||||||
|
}
|
||||||
|
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
|
||||||
|
return fmt.Errorf("cannot set self as invalid request fallback group")
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fallback group not found: %w", err)
|
||||||
|
}
|
||||||
|
if fallbackGroup.Platform != PlatformAnthropic {
|
||||||
|
return fmt.Errorf("fallback group must be anthropic platform")
|
||||||
|
}
|
||||||
|
if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription {
|
||||||
|
return fmt.Errorf("fallback group cannot be subscription type")
|
||||||
|
}
|
||||||
|
if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||||
|
return fmt.Errorf("fallback group cannot have invalid request fallback configured")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||||
group, err := s.groupRepo.GetByID(ctx, id)
|
group, err := s.groupRepo.GetByID(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -717,6 +763,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
group.FallbackGroupID = nil
|
group.FallbackGroupID = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
|
||||||
|
if input.FallbackGroupIDOnInvalidRequest != nil {
|
||||||
|
if *input.FallbackGroupIDOnInvalidRequest > 0 {
|
||||||
|
fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
|
||||||
|
} else {
|
||||||
|
fallbackOnInvalidRequest = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fallbackOnInvalidRequest != nil {
|
||||||
|
if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest
|
||||||
|
|
||||||
// 模型路由配置
|
// 模型路由配置
|
||||||
if input.ModelRouting != nil {
|
if input.ModelRouting != nil {
|
||||||
|
|||||||
@@ -378,3 +378,374 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
|
|||||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type groupRepoStubForInvalidRequestFallback struct {
|
||||||
|
groups map[int64]*Group
|
||||||
|
created *Group
|
||||||
|
updated *Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
|
||||||
|
s.created = g
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
|
||||||
|
s.updated = g
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||||
|
return s.GetByIDLite(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||||
|
if g, ok := s.groups[id]; ok {
|
||||||
|
return g, nil
|
||||||
|
}
|
||||||
|
return nil, ErrGroupNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||||
|
panic("unexpected DeleteCascade call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected List call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListWithFilters call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
|
||||||
|
panic("unexpected ListActive call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||||
|
panic("unexpected ListActiveByPlatform call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByName call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||||
|
panic("unexpected GetAccountCount call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||||
|
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeSubscription,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fallback *Group
|
||||||
|
wantMessage string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "openai_target",
|
||||||
|
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
wantMessage: "fallback group must be anthropic platform",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "antigravity_target",
|
||||||
|
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
wantMessage: "fallback group must be anthropic platform",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subscription_group",
|
||||||
|
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||||
|
wantMessage: "fallback group cannot be subscription type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested_fallback",
|
||||||
|
fallback: &Group{
|
||||||
|
ID: 10,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
|
||||||
|
},
|
||||||
|
wantMessage: "fallback group cannot have invalid request fallback configured",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
fallbackID := tc.fallback.ID
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
fallbackID: tc.fallback,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), tc.wantMessage)
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "fallback group not found")
|
||||||
|
require.Nil(t, repo.created)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.created)
|
||||||
|
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||||
|
zero := int64(0)
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &zero,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.created)
|
||||||
|
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||||
|
require.Nil(t, repo.updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
SubscriptionType: SubscriptionTypeSubscription,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||||
|
require.Nil(t, repo.updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
clear := int64(0)
|
||||||
|
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
FallbackGroupIDOnInvalidRequest: &clear,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.updated)
|
||||||
|
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
|
||||||
|
require.Nil(t, repo.updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.updated)
|
||||||
|
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||||
|
fallbackID := int64(10)
|
||||||
|
existing := &Group{
|
||||||
|
ID: 1,
|
||||||
|
Name: "g1",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &groupRepoStubForInvalidRequestFallback{
|
||||||
|
groups: map[int64]*Group{
|
||||||
|
existing.ID: existing,
|
||||||
|
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|
||||||
|
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||||
|
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, group)
|
||||||
|
require.NotNil(t, repo.updated)
|
||||||
|
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||||
|
}
|
||||||
|
|||||||
@@ -62,6 +62,17 @@ type antigravityRetryLoopResult struct {
|
|||||||
resp *http.Response
|
resp *http.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PromptTooLongError 表示上游明确返回 prompt too long
|
||||||
|
type PromptTooLongError struct {
|
||||||
|
StatusCode int
|
||||||
|
RequestID string
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *PromptTooLongError) Error() string {
|
||||||
|
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||||
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||||
@@ -930,6 +941,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
|
|
||||||
// 处理错误响应(重试后仍失败或不触发重试)
|
// 处理错误响应(重试后仍失败或不触发重试)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "prompt_too_long",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &PromptTooLongError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
RequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Body: respBody,
|
||||||
|
}
|
||||||
|
}
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
@@ -1019,21 +1063,55 @@ func isSignatureRelatedError(respBody []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPromptTooLongError(respBody []byte) bool {
|
||||||
|
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||||
|
if msg == "" {
|
||||||
|
msg = strings.ToLower(string(respBody))
|
||||||
|
}
|
||||||
|
return strings.Contains(msg, "prompt is too long")
|
||||||
|
}
|
||||||
|
|
||||||
func extractAntigravityErrorMessage(body []byte) string {
|
func extractAntigravityErrorMessage(body []byte) string {
|
||||||
var payload map[string]any
|
var payload map[string]any
|
||||||
if err := json.Unmarshal(body, &payload); err != nil {
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parseNestedMessage := func(msg string) string {
|
||||||
|
trimmed := strings.TrimSpace(msg)
|
||||||
|
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var nested map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if errObj, ok := nested["error"].(map[string]any); ok {
|
||||||
|
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Google-style: {"error": {"message": "..."}}
|
// Google-style: {"error": {"message": "..."}}
|
||||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||||
|
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback: top-level message
|
// Fallback: top-level message
|
||||||
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||||
|
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||||
|
return innerMsg
|
||||||
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2209,6 +2287,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
|||||||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||||
|
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||||
statusStr := "UNKNOWN"
|
statusStr := "UNKNOWN"
|
||||||
switch status {
|
switch status {
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,3 +87,77 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
|
|||||||
require.Equal(t, "secret plan", blocks[0]["text"])
|
require.Equal(t, "secret plan", blocks[0]["text"])
|
||||||
require.Equal(t, "tool_use", blocks[1]["type"])
|
require.Equal(t, "tool_use", blocks[1]["type"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsPromptTooLongError(t *testing.T) {
|
||||||
|
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
|
||||||
|
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
|
||||||
|
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpUpstreamStub struct {
|
||||||
|
resp *http.Response
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||||
|
return s.resp, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"model": "claude-opus-4-5",
|
||||||
|
"messages": []map[string]any{
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
"max_tokens": 1,
|
||||||
|
"stream": false,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"req-1"}},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "acc-1",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
require.Nil(t, result)
|
||||||
|
|
||||||
|
var promptErr *PromptTooLongError
|
||||||
|
require.ErrorAs(t, err, &promptErr)
|
||||||
|
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
|
||||||
|
require.Equal(t, "req-1", promptErr.RequestID)
|
||||||
|
require.NotEmpty(t, promptErr.Body)
|
||||||
|
|
||||||
|
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, events, 1)
|
||||||
|
require.Equal(t, "prompt_too_long", events[0].Kind)
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,20 +23,21 @@ type APIKeyAuthUserSnapshot struct {
|
|||||||
|
|
||||||
// APIKeyAuthGroupSnapshot 分组快照
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
type APIKeyAuthGroupSnapshot struct {
|
type APIKeyAuthGroupSnapshot struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Platform string `json:"platform"`
|
Platform string `json:"platform"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
SubscriptionType string `json:"subscription_type"`
|
SubscriptionType string `json:"subscription_type"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||||
|
|
||||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||||
// Only anthropic groups use these fields; others may leave them empty.
|
// Only anthropic groups use these fields; others may leave them empty.
|
||||||
|
|||||||
@@ -207,22 +207,23 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
}
|
}
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||||
ID: apiKey.Group.ID,
|
ID: apiKey.Group.ID,
|
||||||
Name: apiKey.Group.Name,
|
Name: apiKey.Group.Name,
|
||||||
Platform: apiKey.Group.Platform,
|
Platform: apiKey.Group.Platform,
|
||||||
Status: apiKey.Group.Status,
|
Status: apiKey.Group.Status,
|
||||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||||
ModelRouting: apiKey.Group.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
ModelRouting: apiKey.Group.ModelRouting,
|
||||||
|
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return snapshot
|
return snapshot
|
||||||
@@ -250,23 +251,24 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
}
|
}
|
||||||
if snapshot.Group != nil {
|
if snapshot.Group != nil {
|
||||||
apiKey.Group = &Group{
|
apiKey.Group = &Group{
|
||||||
ID: snapshot.Group.ID,
|
ID: snapshot.Group.ID,
|
||||||
Name: snapshot.Group.Name,
|
Name: snapshot.Group.Name,
|
||||||
Platform: snapshot.Group.Platform,
|
Platform: snapshot.Group.Platform,
|
||||||
Status: snapshot.Group.Status,
|
Status: snapshot.Group.Status,
|
||||||
Hydrated: true,
|
Hydrated: true,
|
||||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||||
ModelRouting: snapshot.Group.ModelRouting,
|
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
||||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
ModelRouting: snapshot.Group.ModelRouting,
|
||||||
|
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return apiKey
|
return apiKey
|
||||||
|
|||||||
@@ -55,6 +55,15 @@ func shortSessionHash(sessionHash string) string {
|
|||||||
return sessionHash[:8]
|
return sessionHash[:8]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeClaudeModelForAnthropic(requestedModel string) string {
|
||||||
|
for _, prefix := range anthropicPrefixMappings {
|
||||||
|
if strings.HasPrefix(requestedModel, prefix) {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return requestedModel
|
||||||
|
}
|
||||||
|
|
||||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||||
var (
|
var (
|
||||||
@@ -71,6 +80,12 @@ var (
|
|||||||
"You are a file search specialist for Claude Code", // Explore Agent 版
|
"You are a file search specialist for Claude Code", // Explore Agent 版
|
||||||
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
|
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
|
||||||
}
|
}
|
||||||
|
|
||||||
|
anthropicPrefixMappings = []string{
|
||||||
|
"claude-opus-4-5",
|
||||||
|
"claude-haiku-4-5",
|
||||||
|
"claude-sonnet-4-5",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||||
@@ -951,6 +966,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
|
|||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
|
||||||
|
return s.resolveGroupByID(ctx, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
|
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
|
||||||
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
|
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
|
||||||
return nil
|
return nil
|
||||||
@@ -1016,7 +1035,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 强制平台模式不检查 Claude Code 限制
|
// 强制平台模式不检查 Claude Code 限制
|
||||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" {
|
||||||
return nil, groupID, nil
|
return nil, groupID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1719,6 +1738,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
|||||||
// Antigravity 平台使用专门的模型支持检查
|
// Antigravity 平台使用专门的模型支持检查
|
||||||
return IsAntigravityModelSupported(requestedModel)
|
return IsAntigravityModelSupported(requestedModel)
|
||||||
}
|
}
|
||||||
|
if account.Platform == PlatformAnthropic {
|
||||||
|
requestedModel = normalizeClaudeModelForAnthropic(requestedModel)
|
||||||
|
}
|
||||||
// 其他平台使用账户的模型支持检查
|
// 其他平台使用账户的模型支持检查
|
||||||
return account.IsModelSupported(requestedModel)
|
return account.IsModelSupported(requestedModel)
|
||||||
}
|
}
|
||||||
@@ -2115,17 +2137,29 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||||
body = enforceCacheControlLimit(body)
|
body = enforceCacheControlLimit(body)
|
||||||
|
|
||||||
// 应用模型映射(仅对apikey类型账号)
|
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
|
||||||
originalModel := reqModel
|
originalModel := reqModel
|
||||||
|
mappedModel := reqModel
|
||||||
|
mappingSource := ""
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel = account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if mappedModel != reqModel {
|
||||||
// 替换请求体中的模型名
|
mappingSource = "account"
|
||||||
body = s.replaceModelInBody(body, mappedModel)
|
|
||||||
reqModel = mappedModel
|
|
||||||
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if mappingSource == "" && account.Platform == PlatformAnthropic {
|
||||||
|
normalized := normalizeClaudeModelForAnthropic(reqModel)
|
||||||
|
if normalized != reqModel {
|
||||||
|
mappedModel = normalized
|
||||||
|
mappingSource = "prefix"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mappedModel != reqModel {
|
||||||
|
// 替换请求体中的模型名
|
||||||
|
body = s.replaceModelInBody(body, mappedModel)
|
||||||
|
reqModel = mappedModel
|
||||||
|
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||||
|
}
|
||||||
|
|
||||||
// 获取凭证
|
// 获取凭证
|
||||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||||
@@ -3426,16 +3460,28 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用模型映射(仅对 apikey 类型账号)
|
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
|
||||||
if account.Type == AccountTypeAPIKey {
|
if reqModel != "" {
|
||||||
if reqModel != "" {
|
mappedModel := reqModel
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappingSource := ""
|
||||||
|
if account.Type == AccountTypeAPIKey {
|
||||||
|
mappedModel = account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if mappedModel != reqModel {
|
||||||
body = s.replaceModelInBody(body, mappedModel)
|
mappingSource = "account"
|
||||||
reqModel = mappedModel
|
|
||||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if mappingSource == "" && account.Platform == PlatformAnthropic {
|
||||||
|
normalized := normalizeClaudeModelForAnthropic(reqModel)
|
||||||
|
if normalized != reqModel {
|
||||||
|
mappedModel = normalized
|
||||||
|
mappingSource = "prefix"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mappedModel != reqModel {
|
||||||
|
body = s.replaceModelInBody(body, mappedModel)
|
||||||
|
reqModel = mappedModel
|
||||||
|
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取凭证
|
// 获取凭证
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ type Group struct {
|
|||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
ClaudeCodeOnly bool
|
ClaudeCodeOnly bool
|
||||||
FallbackGroupID *int64
|
FallbackGroupID *int64
|
||||||
|
// 无效请求兜底分组(仅 anthropic 平台使用)
|
||||||
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
|
|
||||||
// 模型路由配置
|
// 模型路由配置
|
||||||
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
-- 043_add_group_invalid_request_fallback.sql
|
||||||
|
-- 添加无效请求兜底分组配置
|
||||||
|
|
||||||
|
-- 添加 fallback_group_id_on_invalid_request 字段:无效请求兜底使用的分组
|
||||||
|
ALTER TABLE groups
|
||||||
|
ADD COLUMN IF NOT EXISTS fallback_group_id_on_invalid_request BIGINT REFERENCES groups(id) ON DELETE SET NULL;
|
||||||
|
|
||||||
|
-- 添加索引优化查询
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id_on_invalid_request
|
||||||
|
ON groups(fallback_group_id_on_invalid_request) WHERE deleted_at IS NULL AND fallback_group_id_on_invalid_request IS NOT NULL;
|
||||||
|
|
||||||
|
-- 添加字段注释
|
||||||
|
COMMENT ON COLUMN groups.fallback_group_id_on_invalid_request IS '无效请求兜底使用的分组 ID';
|
||||||
@@ -919,6 +919,11 @@ export default {
|
|||||||
fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.',
|
fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.',
|
||||||
noFallback: 'No Fallback (Reject)'
|
noFallback: 'No Fallback (Reject)'
|
||||||
},
|
},
|
||||||
|
invalidRequestFallback: {
|
||||||
|
title: 'Invalid Request Fallback Group',
|
||||||
|
hint: 'Triggered only when upstream explicitly returns prompt too long. Leave empty to disable fallback.',
|
||||||
|
noFallback: 'No Fallback'
|
||||||
|
},
|
||||||
modelRouting: {
|
modelRouting: {
|
||||||
title: 'Model Routing',
|
title: 'Model Routing',
|
||||||
tooltip: 'Configure specific model requests to be routed to designated accounts. Supports wildcard matching, e.g., claude-opus-* matches all opus models.',
|
tooltip: 'Configure specific model requests to be routed to designated accounts. Supports wildcard matching, e.g., claude-opus-* matches all opus models.',
|
||||||
|
|||||||
@@ -995,6 +995,11 @@ export default {
|
|||||||
fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝',
|
fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝',
|
||||||
noFallback: '不降级(直接拒绝)'
|
noFallback: '不降级(直接拒绝)'
|
||||||
},
|
},
|
||||||
|
invalidRequestFallback: {
|
||||||
|
title: '无效请求兜底分组',
|
||||||
|
hint: '仅当上游明确返回 prompt too long 时才会触发,留空表示不兜底',
|
||||||
|
noFallback: '不兜底'
|
||||||
|
},
|
||||||
modelRouting: {
|
modelRouting: {
|
||||||
title: '模型路由配置',
|
title: '模型路由配置',
|
||||||
tooltip: '配置特定模型请求优先路由到指定账号。支持通配符匹配,如 claude-opus-* 匹配所有 opus 模型。',
|
tooltip: '配置特定模型请求优先路由到指定账号。支持通配符匹配,如 claude-opus-* 匹配所有 opus 模型。',
|
||||||
|
|||||||
@@ -269,6 +269,7 @@ export interface Group {
|
|||||||
// Claude Code 客户端限制
|
// Claude Code 客户端限制
|
||||||
claude_code_only: boolean
|
claude_code_only: boolean
|
||||||
fallback_group_id: number | null
|
fallback_group_id: number | null
|
||||||
|
fallback_group_id_on_invalid_request: number | null
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
model_routing: Record<string, number[]> | null
|
model_routing: Record<string, number[]> | null
|
||||||
model_routing_enabled: boolean
|
model_routing_enabled: boolean
|
||||||
@@ -322,6 +323,7 @@ export interface CreateGroupRequest {
|
|||||||
image_price_4k?: number | null
|
image_price_4k?: number | null
|
||||||
claude_code_only?: boolean
|
claude_code_only?: boolean
|
||||||
fallback_group_id?: number | null
|
fallback_group_id?: number | null
|
||||||
|
fallback_group_id_on_invalid_request?: number | null
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UpdateGroupRequest {
|
export interface UpdateGroupRequest {
|
||||||
@@ -340,6 +342,7 @@ export interface UpdateGroupRequest {
|
|||||||
image_price_4k?: number | null
|
image_price_4k?: number | null
|
||||||
claude_code_only?: boolean
|
claude_code_only?: boolean
|
||||||
fallback_group_id?: number | null
|
fallback_group_id?: number | null
|
||||||
|
fallback_group_id_on_invalid_request?: number | null
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== Account & Proxy Types ====================
|
// ==================== Account & Proxy Types ====================
|
||||||
|
|||||||
@@ -460,6 +460,20 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 无效请求兜底(仅 anthropic/antigravity 平台,且非订阅分组) -->
|
||||||
|
<div
|
||||||
|
v-if="['anthropic', 'antigravity'].includes(createForm.platform) && createForm.subscription_type !== 'subscription'"
|
||||||
|
class="border-t pt-4"
|
||||||
|
>
|
||||||
|
<label class="input-label">{{ t('admin.groups.invalidRequestFallback.title') }}</label>
|
||||||
|
<Select
|
||||||
|
v-model="createForm.fallback_group_id_on_invalid_request"
|
||||||
|
:options="invalidRequestFallbackOptions"
|
||||||
|
:placeholder="t('admin.groups.invalidRequestFallback.noFallback')"
|
||||||
|
/>
|
||||||
|
<p class="input-hint">{{ t('admin.groups.invalidRequestFallback.hint') }}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- 模型路由配置(仅 anthropic 平台) -->
|
<!-- 模型路由配置(仅 anthropic 平台) -->
|
||||||
<div v-if="createForm.platform === 'anthropic'" class="border-t pt-4">
|
<div v-if="createForm.platform === 'anthropic'" class="border-t pt-4">
|
||||||
<div class="mb-1.5 flex items-center gap-1">
|
<div class="mb-1.5 flex items-center gap-1">
|
||||||
@@ -904,6 +918,20 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 无效请求兜底(仅 anthropic/antigravity 平台,且非订阅分组) -->
|
||||||
|
<div
|
||||||
|
v-if="['anthropic', 'antigravity'].includes(editForm.platform) && editForm.subscription_type !== 'subscription'"
|
||||||
|
class="border-t pt-4"
|
||||||
|
>
|
||||||
|
<label class="input-label">{{ t('admin.groups.invalidRequestFallback.title') }}</label>
|
||||||
|
<Select
|
||||||
|
v-model="editForm.fallback_group_id_on_invalid_request"
|
||||||
|
:options="invalidRequestFallbackOptionsForEdit"
|
||||||
|
:placeholder="t('admin.groups.invalidRequestFallback.noFallback')"
|
||||||
|
/>
|
||||||
|
<p class="input-hint">{{ t('admin.groups.invalidRequestFallback.hint') }}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- 模型路由配置(仅 anthropic 平台) -->
|
<!-- 模型路由配置(仅 anthropic 平台) -->
|
||||||
<div v-if="editForm.platform === 'anthropic'" class="border-t pt-4">
|
<div v-if="editForm.platform === 'anthropic'" class="border-t pt-4">
|
||||||
<div class="mb-1.5 flex items-center gap-1">
|
<div class="mb-1.5 flex items-center gap-1">
|
||||||
@@ -1202,6 +1230,44 @@ const fallbackGroupOptionsForEdit = computed(() => {
|
|||||||
return options
|
return options
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 无效请求兜底分组选项(创建时)- 仅包含 anthropic 平台、非订阅且未配置兜底的分组
|
||||||
|
const invalidRequestFallbackOptions = computed(() => {
|
||||||
|
const options: { value: number | null; label: string }[] = [
|
||||||
|
{ value: null, label: t('admin.groups.invalidRequestFallback.noFallback') }
|
||||||
|
]
|
||||||
|
const eligibleGroups = groups.value.filter(
|
||||||
|
(g) =>
|
||||||
|
g.platform === 'anthropic' &&
|
||||||
|
g.status === 'active' &&
|
||||||
|
g.subscription_type !== 'subscription' &&
|
||||||
|
g.fallback_group_id_on_invalid_request === null
|
||||||
|
)
|
||||||
|
eligibleGroups.forEach((g) => {
|
||||||
|
options.push({ value: g.id, label: g.name })
|
||||||
|
})
|
||||||
|
return options
|
||||||
|
})
|
||||||
|
|
||||||
|
// 无效请求兜底分组选项(编辑时)- 排除自身
|
||||||
|
const invalidRequestFallbackOptionsForEdit = computed(() => {
|
||||||
|
const options: { value: number | null; label: string }[] = [
|
||||||
|
{ value: null, label: t('admin.groups.invalidRequestFallback.noFallback') }
|
||||||
|
]
|
||||||
|
const currentId = editingGroup.value?.id
|
||||||
|
const eligibleGroups = groups.value.filter(
|
||||||
|
(g) =>
|
||||||
|
g.platform === 'anthropic' &&
|
||||||
|
g.status === 'active' &&
|
||||||
|
g.subscription_type !== 'subscription' &&
|
||||||
|
g.fallback_group_id_on_invalid_request === null &&
|
||||||
|
g.id !== currentId
|
||||||
|
)
|
||||||
|
eligibleGroups.forEach((g) => {
|
||||||
|
options.push({ value: g.id, label: g.name })
|
||||||
|
})
|
||||||
|
return options
|
||||||
|
})
|
||||||
|
|
||||||
const groups = ref<Group[]>([])
|
const groups = ref<Group[]>([])
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const searchQuery = ref('')
|
const searchQuery = ref('')
|
||||||
@@ -1243,6 +1309,7 @@ const createForm = reactive({
|
|||||||
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
||||||
claude_code_only: false,
|
claude_code_only: false,
|
||||||
fallback_group_id: null as number | null,
|
fallback_group_id: null as number | null,
|
||||||
|
fallback_group_id_on_invalid_request: null as number | null,
|
||||||
// 模型路由开关
|
// 模型路由开关
|
||||||
model_routing_enabled: false
|
model_routing_enabled: false
|
||||||
})
|
})
|
||||||
@@ -1414,6 +1481,7 @@ const editForm = reactive({
|
|||||||
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
||||||
claude_code_only: false,
|
claude_code_only: false,
|
||||||
fallback_group_id: null as number | null,
|
fallback_group_id: null as number | null,
|
||||||
|
fallback_group_id_on_invalid_request: null as number | null,
|
||||||
// 模型路由开关
|
// 模型路由开关
|
||||||
model_routing_enabled: false
|
model_routing_enabled: false
|
||||||
})
|
})
|
||||||
@@ -1497,6 +1565,7 @@ const closeCreateModal = () => {
|
|||||||
createForm.image_price_4k = null
|
createForm.image_price_4k = null
|
||||||
createForm.claude_code_only = false
|
createForm.claude_code_only = false
|
||||||
createForm.fallback_group_id = null
|
createForm.fallback_group_id = null
|
||||||
|
createForm.fallback_group_id_on_invalid_request = null
|
||||||
createModelRoutingRules.value = []
|
createModelRoutingRules.value = []
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1546,6 +1615,7 @@ const handleEdit = async (group: Group) => {
|
|||||||
editForm.image_price_4k = group.image_price_4k
|
editForm.image_price_4k = group.image_price_4k
|
||||||
editForm.claude_code_only = group.claude_code_only || false
|
editForm.claude_code_only = group.claude_code_only || false
|
||||||
editForm.fallback_group_id = group.fallback_group_id
|
editForm.fallback_group_id = group.fallback_group_id
|
||||||
|
editForm.fallback_group_id_on_invalid_request = group.fallback_group_id_on_invalid_request
|
||||||
editForm.model_routing_enabled = group.model_routing_enabled || false
|
editForm.model_routing_enabled = group.model_routing_enabled || false
|
||||||
// 加载模型路由规则(异步加载账号名称)
|
// 加载模型路由规则(异步加载账号名称)
|
||||||
editModelRoutingRules.value = await convertApiFormatToRoutingRules(group.model_routing)
|
editModelRoutingRules.value = await convertApiFormatToRoutingRules(group.model_routing)
|
||||||
@@ -1571,6 +1641,10 @@ const handleUpdateGroup = async () => {
|
|||||||
const payload = {
|
const payload = {
|
||||||
...editForm,
|
...editForm,
|
||||||
fallback_group_id: editForm.fallback_group_id === null ? 0 : editForm.fallback_group_id,
|
fallback_group_id: editForm.fallback_group_id === null ? 0 : editForm.fallback_group_id,
|
||||||
|
fallback_group_id_on_invalid_request:
|
||||||
|
editForm.fallback_group_id_on_invalid_request === null
|
||||||
|
? 0
|
||||||
|
: editForm.fallback_group_id_on_invalid_request,
|
||||||
model_routing: convertRoutingRulesToApiFormat(editModelRoutingRules.value)
|
model_routing: convertRoutingRulesToApiFormat(editModelRoutingRules.value)
|
||||||
}
|
}
|
||||||
await adminAPI.groups.update(editingGroup.value.id, payload)
|
await adminAPI.groups.update(editingGroup.value.id, payload)
|
||||||
@@ -1612,6 +1686,16 @@ watch(
|
|||||||
if (newVal === 'subscription') {
|
if (newVal === 'subscription') {
|
||||||
createForm.rate_multiplier = 1.0
|
createForm.rate_multiplier = 1.0
|
||||||
createForm.is_exclusive = true
|
createForm.is_exclusive = true
|
||||||
|
createForm.fallback_group_id_on_invalid_request = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
watch(
|
||||||
|
() => createForm.platform,
|
||||||
|
(newVal) => {
|
||||||
|
if (!['anthropic', 'antigravity'].includes(newVal)) {
|
||||||
|
createForm.fallback_group_id_on_invalid_request = null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user