diff --git a/backend/ent/group.go b/backend/ent/group.go index 0d0c0538..1eb05e0e 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -56,10 +56,16 @@ type Group struct { ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + // 无效请求兜底使用的分组 ID + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // 模型路由配置:模型模式 -> 优先账号ID列表 ModelRouting map[string][]int64 `json:"model_routing,omitempty"` // 是否启用模型路由配置 ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"` + // 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台) + McpXMLInject bool `json:"mcp_xml_inject,omitempty"` + // 支持的模型系列:claude, gemini_text, gemini_image + SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -166,13 +172,13 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldModelRouting: + case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -322,6 +328,13 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.FallbackGroupID = new(int64) *_m.FallbackGroupID = value.Int64 } + case group.FieldFallbackGroupIDOnInvalidRequest: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field fallback_group_id_on_invalid_request", values[i]) + } else if value.Valid { + _m.FallbackGroupIDOnInvalidRequest = new(int64) + *_m.FallbackGroupIDOnInvalidRequest = value.Int64 + } case group.FieldModelRouting: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field model_routing", values[i]) @@ -336,6 +349,20 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.ModelRoutingEnabled = value.Bool } + case group.FieldMcpXMLInject: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field mcp_xml_inject", values[i]) + } else if value.Valid { + _m.McpXMLInject = value.Bool + } + case group.FieldSupportedModelScopes: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field supported_model_scopes", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.SupportedModelScopes); err != nil { + return fmt.Errorf("unmarshal field supported_model_scopes: %w", err) + } + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -487,11 +514,22 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.FallbackGroupIDOnInvalidRequest; v != nil { + builder.WriteString("fallback_group_id_on_invalid_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("model_routing=") builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting)) builder.WriteString(", ") builder.WriteString("model_routing_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled)) + builder.WriteString(", ") + builder.WriteString("mcp_xml_inject=") + builder.WriteString(fmt.Sprintf("%v", _m.McpXMLInject)) + builder.WriteString(", ") + builder.WriteString("supported_model_scopes=") + builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index d66d3edc..278b2daf 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -53,10 +53,16 @@ const ( FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. FieldFallbackGroupID = "fallback_group_id" + // FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database. + FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request" // FieldModelRouting holds the string denoting the model_routing field in the database. FieldModelRouting = "model_routing" // FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database. FieldModelRoutingEnabled = "model_routing_enabled" + // FieldMcpXMLInject holds the string denoting the mcp_xml_inject field in the database. + FieldMcpXMLInject = "mcp_xml_inject" + // FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database. + FieldSupportedModelScopes = "supported_model_scopes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -151,8 +157,11 @@ var Columns = []string{ FieldImagePrice4k, FieldClaudeCodeOnly, FieldFallbackGroupID, + FieldFallbackGroupIDOnInvalidRequest, FieldModelRouting, FieldModelRoutingEnabled, + FieldMcpXMLInject, + FieldSupportedModelScopes, } var ( @@ -212,6 +221,10 @@ var ( DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. DefaultModelRoutingEnabled bool + // DefaultMcpXMLInject holds the default value on creation for the "mcp_xml_inject" field. + DefaultMcpXMLInject bool + // DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field. + DefaultSupportedModelScopes []string ) // OrderOption defines the ordering options for the Group queries. @@ -317,11 +330,21 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() } +// ByFallbackGroupIDOnInvalidRequest orders the results by the fallback_group_id_on_invalid_request field. +func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc() +} + // ByModelRoutingEnabled orders the results by the model_routing_enabled field. func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc() } +// ByMcpXMLInject orders the results by the mcp_xml_inject field. +func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 6ce9e4c6..b6fa2c33 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -150,11 +150,21 @@ func FallbackGroupID(v int64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) } +// FallbackGroupIDOnInvalidRequest applies equality check predicate on the "fallback_group_id_on_invalid_request" field. It's identical to FallbackGroupIDOnInvalidRequestEQ. +func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + // ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ. func ModelRoutingEnabled(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v)) } +// McpXMLInject applies equality check predicate on the "mcp_xml_inject" field. It's identical to McpXMLInjectEQ. +func McpXMLInject(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1070,6 +1080,56 @@ func FallbackGroupIDNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) } +// FallbackGroupIDOnInvalidRequestEQ applies the EQ predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestNEQ applies the NEQ predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestIn applies the In predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldFallbackGroupIDOnInvalidRequest, vs...)) +} + +// FallbackGroupIDOnInvalidRequestNotIn applies the NotIn predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldFallbackGroupIDOnInvalidRequest, vs...)) +} + +// FallbackGroupIDOnInvalidRequestGT applies the GT predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestGTE applies the GTE predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestLT applies the LT predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestLTE applies the LTE predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestIsNil applies the IsNil predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldFallbackGroupIDOnInvalidRequest)) +} + +// FallbackGroupIDOnInvalidRequestNotNil applies the NotNil predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest)) +} + // ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field. func ModelRoutingIsNil() predicate.Group { return predicate.Group(sql.FieldIsNull(FieldModelRouting)) @@ -1090,6 +1150,16 @@ func ModelRoutingEnabledNEQ(v bool) predicate.Group { return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v)) } +// McpXMLInjectEQ applies the EQ predicate on the "mcp_xml_inject" field. +func McpXMLInjectEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v)) +} + +// McpXMLInjectNEQ applies the NEQ predicate on the "mcp_xml_inject" field. +func McpXMLInjectNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 0f251e0b..9d845b61 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -286,6 +286,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { return _c } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_c *GroupCreate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupCreate { + _c.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _c +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupCreate { + if v != nil { + _c.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _c +} + // SetModelRouting sets the "model_routing" field. func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate { _c.mutation.SetModelRouting(v) @@ -306,6 +320,26 @@ func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate { return _c } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_c *GroupCreate) SetMcpXMLInject(v bool) *GroupCreate { + _c.mutation.SetMcpXMLInject(v) + return _c +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_c *GroupCreate) SetNillableMcpXMLInject(v *bool) *GroupCreate { + if v != nil { + _c.SetMcpXMLInject(*v) + } + return _c +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate { + _c.mutation.SetSupportedModelScopes(v) + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -479,6 +513,14 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultModelRoutingEnabled _c.mutation.SetModelRoutingEnabled(v) } + if _, ok := _c.mutation.McpXMLInject(); !ok { + v := group.DefaultMcpXMLInject + _c.mutation.SetMcpXMLInject(v) + } + if _, ok := _c.mutation.SupportedModelScopes(); !ok { + v := group.DefaultSupportedModelScopes + _c.mutation.SetSupportedModelScopes(v) + } return nil } @@ -537,6 +579,12 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.ModelRoutingEnabled(); !ok { return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)} } + if _, ok := _c.mutation.McpXMLInject(); !ok { + return &ValidationError{Name: "mcp_xml_inject", err: errors.New(`ent: missing required field "Group.mcp_xml_inject"`)} + } + if _, ok := _c.mutation.SupportedModelScopes(); !ok { + return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)} + } return nil } @@ -640,6 +688,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) _node.FallbackGroupID = &value } + if value, ok := _c.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + _node.FallbackGroupIDOnInvalidRequest = &value + } if value, ok := _c.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) _node.ModelRouting = value @@ -648,6 +700,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) _node.ModelRoutingEnabled = value } + if value, ok := _c.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + _node.McpXMLInject = value + } + if value, ok := _c.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + _node.SupportedModelScopes = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1128,6 +1188,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { return u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert { + u.Set(group.FieldFallbackGroupIDOnInvalidRequest, v) + return u +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsert { + u.SetExcluded(group.FieldFallbackGroupIDOnInvalidRequest) + return u +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert { + u.Add(group.FieldFallbackGroupIDOnInvalidRequest, v) + return u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert { + u.SetNull(group.FieldFallbackGroupIDOnInvalidRequest) + return u +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert { u.Set(group.FieldModelRouting, v) @@ -1158,6 +1242,30 @@ func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert { return u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsert) SetMcpXMLInject(v bool) *GroupUpsert { + u.Set(group.FieldMcpXMLInject, v) + return u +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsert) UpdateMcpXMLInject() *GroupUpsert { + u.SetExcluded(group.FieldMcpXMLInject) + return u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsert) SetSupportedModelScopes(v []string) *GroupUpsert { + u.Set(group.FieldSupportedModelScopes, v) + return u +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert { + u.SetExcluded(group.FieldSupportedModelScopes) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1581,6 +1689,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { }) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupIDOnInvalidRequest(v) + }) +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupIDOnInvalidRequest(v) + }) +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupIDOnInvalidRequest() + }) +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupIDOnInvalidRequest() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -1616,6 +1752,34 @@ func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne { }) } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsertOne) SetMcpXMLInject(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetMcpXMLInject(v) + }) +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateMcpXMLInject() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateMcpXMLInject() + }) +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsertOne) SetSupportedModelScopes(v []string) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSupportedModelScopes(v) + }) +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSupportedModelScopes() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2205,6 +2369,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { }) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupIDOnInvalidRequest(v) + }) +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupIDOnInvalidRequest(v) + }) +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupIDOnInvalidRequest() + }) +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupIDOnInvalidRequest() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { @@ -2240,6 +2432,34 @@ func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk { }) } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsertBulk) SetMcpXMLInject(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetMcpXMLInject(v) + }) +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateMcpXMLInject() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateMcpXMLInject() + }) +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsertBulk) SetSupportedModelScopes(v []string) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSupportedModelScopes(v) + }) +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSupportedModelScopes() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index c3cc2708..9e7246ea 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -10,6 +10,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/apikey" @@ -395,6 +396,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { return _u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate { + _u.mutation.ResetFallbackGroupIDOnInvalidRequest() + _u.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdate { + if v != nil { + _u.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _u +} + +// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate { + _u.mutation.AddFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate { + _u.mutation.ClearFallbackGroupIDOnInvalidRequest() + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate { _u.mutation.SetModelRouting(v) @@ -421,6 +449,32 @@ func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate { return _u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_u *GroupUpdate) SetMcpXMLInject(v bool) *GroupUpdate { + _u.mutation.SetMcpXMLInject(v) + return _u +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableMcpXMLInject(v *bool) *GroupUpdate { + if v != nil { + _u.SetMcpXMLInject(*v) + } + return _u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_u *GroupUpdate) SetSupportedModelScopes(v []string) *GroupUpdate { + _u.mutation.SetSupportedModelScopes(v) + return _u +} + +// AppendSupportedModelScopes appends value to the "supported_model_scopes" field. +func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate { + _u.mutation.AppendSupportedModelScopes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -829,6 +883,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok { + _spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { + _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } @@ -838,6 +901,17 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.ModelRoutingEnabled(); ok { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + } + if value, ok := _u.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, group.FieldSupportedModelScopes, value) + }) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1513,6 +1587,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { return _u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne { + _u.mutation.ResetFallbackGroupIDOnInvalidRequest() + _u.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _u +} + +// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne { + _u.mutation.AddFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne { + _u.mutation.ClearFallbackGroupIDOnInvalidRequest() + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne { _u.mutation.SetModelRouting(v) @@ -1539,6 +1640,32 @@ func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOn return _u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_u *GroupUpdateOne) SetMcpXMLInject(v bool) *GroupUpdateOne { + _u.mutation.SetMcpXMLInject(v) + return _u +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableMcpXMLInject(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetMcpXMLInject(*v) + } + return _u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_u *GroupUpdateOne) SetSupportedModelScopes(v []string) *GroupUpdateOne { + _u.mutation.SetSupportedModelScopes(v) + return _u +} + +// AppendSupportedModelScopes appends value to the "supported_model_scopes" field. +func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne { + _u.mutation.AppendSupportedModelScopes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1977,6 +2104,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok { + _spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { + _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } @@ -1986,6 +2122,17 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.ModelRoutingEnabled(); ok { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + } + if value, ok := _u.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, group.FieldSupportedModelScopes, value) + }) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index ee6b69c8..dc91f6a5 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -331,8 +331,11 @@ var ( {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, + {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, + {Name: "mcp_xml_inject", Type: field.TypeBool, Default: true}, + {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 3cc3b36f..77d208e1 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -5753,61 +5753,66 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { // GroupMutation represents an operation that mutates the Group nodes in the graph. type GroupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + fallback_group_id_on_invalid_request *int64 + addfallback_group_id_on_invalid_request *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + mcp_xml_inject *bool + supported_model_scopes *[]string + appendsupported_model_scopes []string + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } var _ ent.Mutation = (*GroupMutation)(nil) @@ -6896,6 +6901,76 @@ func (m *GroupMutation) ResetFallbackGroupID() { delete(m.clearedFields, group.FieldFallbackGroupID) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { + m.fallback_group_id_on_invalid_request = &i + m.addfallback_group_id_on_invalid_request = nil +} + +// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.fallback_group_id_on_invalid_request + if v == nil { + return + } + return *v, true +} + +// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + } + return oldValue.FallbackGroupIDOnInvalidRequest, nil +} + +// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { + if m.addfallback_group_id_on_invalid_request != nil { + *m.addfallback_group_id_on_invalid_request += i + } else { + m.addfallback_group_id_on_invalid_request = &i + } +} + +// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.addfallback_group_id_on_invalid_request + if v == nil { + return + } + return *v, true +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} +} + +// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] + return ok +} + +// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) +} + // SetModelRouting sets the "model_routing" field. func (m *GroupMutation) SetModelRouting(value map[string][]int64) { m.model_routing = &value @@ -6981,6 +7056,93 @@ func (m *GroupMutation) ResetModelRoutingEnabled() { m.model_routing_enabled = nil } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (m *GroupMutation) SetMcpXMLInject(b bool) { + m.mcp_xml_inject = &b +} + +// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation. +func (m *GroupMutation) McpXMLInject() (r bool, exists bool) { + v := m.mcp_xml_inject + if v == nil { + return + } + return *v, true +} + +// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMcpXMLInject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err) + } + return oldValue.McpXMLInject, nil +} + +// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field. +func (m *GroupMutation) ResetMcpXMLInject() { + m.mcp_xml_inject = nil +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (m *GroupMutation) SetSupportedModelScopes(s []string) { + m.supported_model_scopes = &s + m.appendsupported_model_scopes = nil +} + +// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation. +func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) { + v := m.supported_model_scopes + if v == nil { + return + } + return *v, true +} + +// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err) + } + return oldValue.SupportedModelScopes, nil +} + +// AppendSupportedModelScopes adds s to the "supported_model_scopes" field. +func (m *GroupMutation) AppendSupportedModelScopes(s []string) { + m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...) +} + +// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation. +func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) { + if len(m.appendsupported_model_scopes) == 0 { + return nil, false + } + return m.appendsupported_model_scopes, true +} + +// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field. +func (m *GroupMutation) ResetSupportedModelScopes() { + m.supported_model_scopes = nil + m.appendsupported_model_scopes = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -7339,7 +7501,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 21) + fields := make([]string, 0, 24) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -7397,12 +7559,21 @@ func (m *GroupMutation) Fields() []string { if m.fallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.fallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } if m.model_routing != nil { fields = append(fields, group.FieldModelRouting) } if m.model_routing_enabled != nil { fields = append(fields, group.FieldModelRoutingEnabled) } + if m.mcp_xml_inject != nil { + fields = append(fields, group.FieldMcpXMLInject) + } + if m.supported_model_scopes != nil { + fields = append(fields, group.FieldSupportedModelScopes) + } return fields } @@ -7449,10 +7620,16 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: return m.FallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.FallbackGroupIDOnInvalidRequest() case group.FieldModelRouting: return m.ModelRouting() case group.FieldModelRoutingEnabled: return m.ModelRoutingEnabled() + case group.FieldMcpXMLInject: + return m.McpXMLInject() + case group.FieldSupportedModelScopes: + return m.SupportedModelScopes() } return nil, false } @@ -7500,10 +7677,16 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: return m.OldFallbackGroupID(ctx) + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.OldFallbackGroupIDOnInvalidRequest(ctx) case group.FieldModelRouting: return m.OldModelRouting(ctx) case group.FieldModelRoutingEnabled: return m.OldModelRoutingEnabled(ctx) + case group.FieldMcpXMLInject: + return m.OldMcpXMLInject(ctx) + case group.FieldSupportedModelScopes: + return m.OldSupportedModelScopes(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -7646,6 +7829,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetFallbackGroupID(v) return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupIDOnInvalidRequest(v) + return nil case group.FieldModelRouting: v, ok := value.(map[string][]int64) if !ok { @@ -7660,6 +7850,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetModelRoutingEnabled(v) return nil + case group.FieldMcpXMLInject: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMcpXMLInject(v) + return nil + case group.FieldSupportedModelScopes: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedModelScopes(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -7695,6 +7899,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.addfallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } return fields } @@ -7721,6 +7928,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice4k() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.AddedFallbackGroupIDOnInvalidRequest() } return nil, false } @@ -7793,6 +8002,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddFallbackGroupID(v) return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupIDOnInvalidRequest(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -7828,6 +8044,9 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } + if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } if m.FieldCleared(group.FieldModelRouting) { fields = append(fields, group.FieldModelRouting) } @@ -7872,6 +8091,9 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ClearFallbackGroupIDOnInvalidRequest() + return nil case group.FieldModelRouting: m.ClearModelRouting() return nil @@ -7940,12 +8162,21 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldFallbackGroupID: m.ResetFallbackGroupID() return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ResetFallbackGroupIDOnInvalidRequest() + return nil case group.FieldModelRouting: m.ResetModelRouting() return nil case group.FieldModelRoutingEnabled: m.ResetModelRoutingEnabled() return nil + case group.FieldMcpXMLInject: + m.ResetMcpXMLInject() + return nil + case group.FieldSupportedModelScopes: + m.ResetSupportedModelScopes() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index c963e23e..f1fea8cc 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -342,9 +342,17 @@ func init() { // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[17].Descriptor() + groupDescModelRoutingEnabled := groupFields[18].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) + // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. + groupDescMcpXMLInject := groupFields[19].Descriptor() + // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. + group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) + // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. + groupDescSupportedModelScopes := groupFields[20].Descriptor() + // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. + group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index ccd72eac..8a3c1a90 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -95,6 +95,10 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). Comment("非 Claude Code 请求降级使用的分组 ID"), + field.Int64("fallback_group_id_on_invalid_request"). + Optional(). + Nillable(). + Comment("无效请求兜底使用的分组 ID"), // 模型路由配置 (added by migration 040) field.JSON("model_routing", map[string][]int64{}). @@ -106,6 +110,17 @@ func (Group) Fields() []ent.Field { field.Bool("model_routing_enabled"). Default(false). Comment("是否启用模型路由配置"), + + // MCP XML 协议注入开关 (added by migration 042) + field.Bool("mcp_xml_inject"). + Default(true). + Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"), + + // 支持的模型系列 (added by migration 046) + field.JSON("supported_model_scopes", []string{}). + Default([]string{"claude", "gemini_text", "gemini_image"}). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}). + Comment("支持的模型系列:claude, gemini_text, gemini_image"), } } diff --git a/backend/go.mod b/backend/go.mod index 4c3e6246..9a36a0f1 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -4,6 +4,8 @@ go 1.25.6 require ( entgo.io/ent v0.14.5 + github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/dgraph-io/ristretto v0.2.0 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 @@ -11,7 +13,10 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/imroc/req/v3 v3.57.0 github.com/lib/pq v1.10.9 + github.com/pquerna/otp v1.5.0 github.com/redis/go-redis/v9 v9.17.2 + github.com/refraction-networking/utls v1.8.1 + github.com/robfig/cron/v3 v3.0.1 github.com/shirou/gopsutil/v4 v4.25.6 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.11.1 @@ -25,13 +30,13 @@ require ( golang.org/x/sync v0.19.0 golang.org/x/term v0.38.0 gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.44.3 ) require ( ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect - github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect @@ -48,7 +53,6 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dgraph-io/ristretto v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect @@ -107,13 +111,10 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect - github.com/pquerna/otp v1.5.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.57.1 // indirect - github.com/refraction-networking/utls v1.8.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.2.0 // indirect - github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect @@ -149,12 +150,10 @@ require ( golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.32.0 // indirect golang.org/x/tools v0.39.0 // indirect - golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect - modernc.org/sqlite v1.44.1 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index 0addb5bb..371623ad 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -55,6 +55,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE= github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU= +github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= +github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -113,6 +115,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -123,6 +127,9 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo= @@ -345,8 +352,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= @@ -374,9 +379,8 @@ golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= -golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY= -golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= +golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -399,12 +403,32 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas= -modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.44.3 h1:+39JvV/HWMcYslAwRxHb8067w+2zowvFOUrOWIy9PjY= +modernc.org/sqlite v1.44.3/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 3655e07f..35a6a5b7 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -29,6 +29,7 @@ const ( AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeAPIKey = "apikey" // API Key类型账号 + AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) ) // Redeem type constants diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index bbf5d026..6d42f726 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -84,7 +84,7 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"` Credentials map[string]any `json:"credentials" binding:"required"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -102,7 +102,7 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes *string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index f93edbc8..d10d678b 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -35,14 +35,18 @@ type CreateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` + MCPXMLInject *bool `json:"mcp_xml_inject"` + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -60,14 +64,18 @@ type UpdateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - ClaudeCodeOnly *bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly *bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled *bool `json:"model_routing_enabled"` + MCPXMLInject *bool `json:"mcp_xml_inject"` + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes *[]string `json:"supported_model_scopes"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -159,23 +167,26 @@ func (h *GroupHandler) Create(c *gin.Context) { } group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, - CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, + MCPXMLInject: req.MCPXMLInject, + SupportedModelScopes: req.SupportedModelScopes, + CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { response.ErrorFrom(c, err) @@ -201,24 +212,27 @@ func (h *GroupHandler) Update(c *gin.Context) { } group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: req.Status, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, - CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: req.Status, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, + MCPXMLInject: req.MCPXMLInject, + SupportedModelScopes: req.SupportedModelScopes, + CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 3b3884d8..4f8d1eeb 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -108,10 +108,12 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - AccountCount: g.AccountCount, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) @@ -141,8 +143,10 @@ func groupFromServiceBase(g *service.Group) Group { ImagePrice4K: g.ImagePrice4K, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + // 无效请求兜底分组 + FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index af659f39..8e6faf02 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -72,6 +72,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + // 无效请求兜底分组 + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` @@ -86,8 +88,13 @@ type AdminGroup struct { ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` - AccountGroups []AccountGroup `json:"account_groups,omitempty"` - AccountCount int64 `json:"account_count,omitempty"` + // MCP XML 协议注入(仅 antigravity 平台使用) + MCPXMLInject bool `json:"mcp_xml_inject"` + + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes"` + AccountGroups []AccountGroup `json:"account_groups,omitempty"` + AccountCount int64 `json:"account_count,omitempty"` } type Account struct { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 217e083a..ccf06b7f 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -284,10 +285,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) + result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body) } else { - result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + result, err = h.geminiCompatService.Forward(requestCtx, c, account, body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -335,140 +340,193 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } } - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + currentAPIKey := apiKey + currentSubscription := subscription + var fallbackGroupID *int64 + if apiKey.Group != nil { + fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest + } + fallbackUsed := false for { - // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) - if err != nil { - if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) - return - } - account := selection.Account - setOpsSelectedAccount(c, account.ID) + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 + retryWithFallback := false - // 检查请求拦截(预热请求、SUGGESTION MODE等) - if account.IsInterceptWarmupEnabled() { - interceptType := detectInterceptType(body) - if interceptType != InterceptTypeNone { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } - if reqStream { - sendMockInterceptStream(c, reqModel, interceptType) - } else { - sendMockInterceptResponse(c, reqModel, interceptType) - } - return - } - } - - // 3. 获取账号并发槽位 - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + for { + // 选择支持该模型的账号 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } - } - // 账号槽位/等待计数需要在超时或断开时安全回收 - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - - // 转发请求 - 根据账号平台分流 - var result *service.ForwardResult - if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) - } else { - result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) - } - if accountReleaseFunc != nil { - accountReleaseFunc() - } - if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - continue + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + account := selection.Account + setOpsSelectedAccount(c, account.ID) + + // 检查请求拦截(预热请求、SUGGESTION MODE等) + if account.IsInterceptWarmupEnabled() { + interceptType := detectInterceptType(body) + if interceptType != InterceptTypeNone { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if reqStream { + sendMockInterceptStream(c, reqModel, interceptType) + } else { + sendMockInterceptResponse(c, reqModel, interceptType) + } + return + } + } + + // 3. 获取账号并发槽位 + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } + } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 转发请求 - 根据账号平台分流 + var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } + if account.Platform == service.PlatformAntigravity { + result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body) + } else { + result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) + } + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var promptTooLongErr *service.PromptTooLongError + if errors.As(err, &promptTooLongErr) { + log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed) + if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { + fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) + if err != nil { + log.Printf("Resolve fallback group failed: %v", err) + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + if fallbackGroup.Platform != service.PlatformAnthropic || + fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription || + fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType) + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") + c.Request = c.Request.WithContext(ctx) + currentAPIKey = fallbackAPIKey + currentSubscription = nil + fallbackUsed = true + retryWithFallback = true + break + } + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + lastFailoverStatus = failoverErr.StatusCode + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // 错误响应已在Forward中处理,这里只记录日志 + log.Printf("Account %d: Forward request failed: %v", account.ID, err) + return + } + + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: currentAPIKey, + User: currentAPIKey.User, + Account: usedAccount, + Subscription: currentSubscription, + UserAgent: ua, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account, userAgent, clientIP) + return + } + if !retryWithFallback { return } - - // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) - userAgent := c.GetHeader("User-Agent") - clientIP := ip.GetClientIP(c) - - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, - }); err != nil { - log.Printf("Record usage failed: %v", err) - } - }(result, account, userAgent, clientIP) - return } } @@ -532,6 +590,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) { }) } +func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey { + if apiKey == nil || group == nil { + return apiKey + } + cloned := *apiKey + groupID := group.ID + cloned.GroupID = &groupID + cloned.Group = group + return &cloned +} + // Usage handles getting account balance and usage statistics for CC Switch integration // GET /v1/usage func (h *GatewayHandler) Usage(c *gin.Context) { diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 23848644..787e3760 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" @@ -335,10 +336,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 5) forward (根据平台分流) var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body) + result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body) } else { - result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) } if accountReleaseFunc != nil { accountReleaseFunc() diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index c7d657b9..d1712c98 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -40,17 +40,48 @@ const ( // URL 可用性 TTL(不可用 URL 的恢复时间) URLAvailabilityTTL = 5 * time.Minute + + // Antigravity API 端点 + antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com" + antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) var BaseURLs = []string{ - "https://cloudcode-pa.googleapis.com", // prod (优先) - "https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用) + antigravityProdBaseURL, // prod (优先) + antigravityDailyBaseURL, // daily sandbox (备用) } // BaseURL 默认 URL(保持向后兼容) var BaseURL = BaseURLs[0] +// ForwardBaseURLs 返回 API 转发用的 URL 顺序(daily 优先) +func ForwardBaseURLs() []string { + if len(BaseURLs) == 0 { + return nil + } + urls := append([]string(nil), BaseURLs...) + dailyIndex := -1 + for i, url := range urls { + if url == antigravityDailyBaseURL { + dailyIndex = i + break + } + } + if dailyIndex <= 0 { + return urls + } + reordered := make([]string, 0, len(urls)) + reordered = append(reordered, urls[dailyIndex]) + for i, url := range urls { + if i == dailyIndex { + continue + } + reordered = append(reordered, url) + } + return reordered +} + // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) type URLAvailability struct { mu sync.RWMutex @@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool { // GetAvailableURLs 返回可用的 URL 列表 // 最近成功的 URL 优先,其他按默认顺序 func (u *URLAvailability) GetAvailableURLs() []string { + return u.GetAvailableURLsWithBase(BaseURLs) +} + +// GetAvailableURLsWithBase 返回可用的 URL 列表(使用自定义顺序) +// 最近成功的 URL 优先,其他按传入顺序 +func (u *URLAvailability) GetAvailableURLsWithBase(baseURLs []string) []string { u.mu.RLock() defer u.mu.RUnlock() now := time.Now() - result := make([]string, 0, len(BaseURLs)) + result := make([]string, 0, len(baseURLs)) // 如果有最近成功的 URL 且可用,放在最前面 if u.lastSuccess != "" { - expiry, exists := u.unavailable[u.lastSuccess] - if !exists || now.After(expiry) { - result = append(result, u.lastSuccess) + found := false + for _, url := range baseURLs { + if url == u.lastSuccess { + found = true + break + } + } + if found { + expiry, exists := u.unavailable[u.lastSuccess] + if !exists || now.After(expiry) { + result = append(result, u.lastSuccess) + } } } - // 添加其他可用的 URL(按默认顺序) - for _, url := range BaseURLs { + // 添加其他可用的 URL(按传入顺序) + for _, url := range baseURLs { // 跳过已添加的 lastSuccess if url == u.lastSuccess { continue diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 63f6ee7c..a75bf6b3 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -44,11 +44,13 @@ type TransformOptions struct { // IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词; // 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。 IdentityPatch string + EnableMCPXML bool } func DefaultTransformOptions() TransformOptions { return TransformOptions{ EnableIdentityPatch: true, + EnableMCPXML: true, } } @@ -257,8 +259,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans // 添加用户的 system prompt parts = append(parts, userSystemParts...) - // 检测是否有 MCP 工具,如有则注入 XML 调用协议 - if hasMCPTools(tools) { + // 检测是否有 MCP 工具,如有且启用了 MCP XML 注入则注入 XML 调用协议 + if opts.EnableMCPXML && hasMCPTools(tools) { parts = append(parts, GeminiPart{Text: mcpXMLProtocol}) } @@ -492,9 +494,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string { } // buildGenerationConfig 构建 generationConfig +const ( + defaultMaxOutputTokens = 64000 + maxOutputTokensUpperBound = 65000 + maxOutputTokensClaude = 64000 +) + +func maxOutputTokensLimit(model string) int { + if strings.HasPrefix(model, "claude-") { + return maxOutputTokensClaude + } + return maxOutputTokensUpperBound +} + func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { + maxLimit := maxOutputTokensLimit(req.Model) config := &GeminiGenerationConfig{ - MaxOutputTokens: 64000, // 默认最大输出 + MaxOutputTokens: defaultMaxOutputTokens, // 默认最大输出 StopSequences: DefaultStopSequences, } @@ -518,6 +534,10 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { } } + if config.MaxOutputTokens > maxLimit { + config.MaxOutputTokens = maxLimit + } + // 其他参数 if req.Temperature != nil { config.Temperature = req.Temperature diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 27bb5ac5..fd7512f7 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -14,6 +14,9 @@ const ( // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 RetryCount Key = "ctx_retry_count" + // AccountSwitchCount 表示请求过程中发生的账号切换次数 + AccountSwitchCount Key = "ctx_account_switch_count" + // IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端 IsClaudeCodeClient Key = "ctx_is_claude_code_client" // Group 认证后的分组信息,由 API Key 认证中间件设置 diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 31b92281..c0cfd256 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -142,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice4k, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, + group.FieldFallbackGroupIDOnInvalidRequest, group.FieldModelRoutingEnabled, group.FieldModelRouting, + group.FieldMcpXMLInject, + group.FieldSupportedModelScopes, ) }). Only(ctx) @@ -459,28 +462,31 @@ func groupEntityToService(g *dbent.Group) *service.Group { return nil } return &service.Group{ - ID: g.ID, - Name: g.Name, - Description: derefString(g.Description), - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - Hydrated: true, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUsd, - WeeklyLimitUSD: g.WeeklyLimitUsd, - MonthlyLimitUSD: g.MonthlyLimitUsd, - ImagePrice1K: g.ImagePrice1k, - ImagePrice2K: g.ImagePrice2k, - ImagePrice4K: g.ImagePrice4k, - DefaultValidityDays: g.DefaultValidityDays, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + ID: g.ID, + Name: g.Name, + Description: derefString(g.Description), + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + Hydrated: true, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUsd, + WeeklyLimitUSD: g.WeeklyLimitUsd, + MonthlyLimitUSD: g.MonthlyLimitUsd, + ImagePrice1K: g.ImagePrice1k, + ImagePrice2K: g.ImagePrice2k, + ImagePrice4K: g.ImagePrice4k, + DefaultValidityDays: g.DefaultValidityDays, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.McpXMLInject, + SupportedModelScopes: g.SupportedModelScopes, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index a5b0512d..d8cec491 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). - SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) + SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). + SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). + SetMcpXMLInject(groupIn.MCPXMLInject) // 设置模型路由配置 if groupIn.ModelRouting != nil { builder = builder.SetModelRouting(groupIn.ModelRouting) } + // 设置支持的模型系列(始终设置,空数组表示不限制) + builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes) + created, err := builder.Save(ctx) if err == nil { groupIn.ID = created.ID @@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G if err != nil { return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) } - return groupEntityToService(m), nil } @@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice4k(groupIn.ImagePrice4K). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). - SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) + SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). + SetMcpXMLInject(groupIn.MCPXMLInject) // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { @@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er } else { builder = builder.ClearFallbackGroupID() } + // 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置 + if groupIn.FallbackGroupIDOnInvalidRequest != nil { + builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest) + } else { + builder = builder.ClearFallbackGroupIDOnInvalidRequest() + } // 处理 ModelRouting:nil 时清除,否则设置 if groupIn.ModelRouting != nil { @@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er builder = builder.ClearModelRouting() } + // 处理 SupportedModelScopes(始终设置,空数组表示不限制) + builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes) + updated, err := builder.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) diff --git a/backend/internal/repository/ops_repo_metrics.go b/backend/internal/repository/ops_repo_metrics.go index 713e0eb9..f1e57c38 100644 --- a/backend/internal/repository/ops_repo_metrics.go +++ b/backend/internal/repository/ops_repo_metrics.go @@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics ( upstream_529_count, token_consumed, + account_switch_count, qps, tps, @@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics ( $1,$2,$3,$4, $5,$6,$7,$8, $9,$10,$11, - $12,$13,$14, - $15,$16,$17,$18,$19,$20, - $21,$22,$23,$24,$25,$26, - $27,$28,$29,$30, - $31,$32, - $33,$34, - $35,$36,$37, - $38,$39 + $12,$13,$14,$15, + $16,$17,$18,$19,$20,$21, + $22,$23,$24,$25,$26,$27, + $28,$29,$30,$31, + $32,$33, + $34,$35, + $36,$37,$38, + $39,$40 )` _, err := r.db.ExecContext( @@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics ( input.Upstream529Count, input.TokenConsumed, + input.AccountSwitchCount, opsNullFloat64(input.QPS), opsNullFloat64(input.TPS), @@ -177,7 +179,8 @@ SELECT db_conn_waiting, goroutine_count, - concurrency_queue_depth + concurrency_queue_depth, + account_switch_count FROM ops_system_metrics WHERE window_minutes = $1 AND platform IS NULL @@ -199,6 +202,7 @@ LIMIT 1` var dbWaiting sql.NullInt64 var goroutines sql.NullInt64 var queueDepth sql.NullInt64 + var accountSwitchCount sql.NullInt64 if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan( &out.ID, @@ -217,6 +221,7 @@ LIMIT 1` &dbWaiting, &goroutines, &queueDepth, + &accountSwitchCount, ); err != nil { return nil, err } @@ -273,6 +278,10 @@ LIMIT 1` v := int(queueDepth.Int64) out.ConcurrencyQueueDepth = &v } + if accountSwitchCount.Valid { + v := accountSwitchCount.Int64 + out.AccountSwitchCount = &v + } return &out, nil } diff --git a/backend/internal/repository/ops_repo_trends.go b/backend/internal/repository/ops_repo_trends.go index 022d1187..14394ed8 100644 --- a/backend/internal/repository/ops_repo_trends.go +++ b/backend/internal/repository/ops_repo_trends.go @@ -56,18 +56,44 @@ error_buckets AS ( AND COALESCE(status_code, 0) >= 400 GROUP BY 1 ), +switch_buckets AS ( + SELECT ` + errorBucketExpr + ` AS bucket, + COALESCE(SUM(CASE + WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1 + ELSE 0 + END), 0) AS switch_count + FROM ops_error_logs + CROSS JOIN LATERAL jsonb_array_elements( + COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb) + ) AS ev + ` + errorWhere + ` + AND upstream_errors IS NOT NULL + GROUP BY 1 +), combined AS ( - SELECT COALESCE(u.bucket, e.bucket) AS bucket, - COALESCE(u.success_count, 0) AS success_count, - COALESCE(e.error_count, 0) AS error_count, - COALESCE(u.token_consumed, 0) AS token_consumed - FROM usage_buckets u - FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket + SELECT + bucket, + SUM(success_count) AS success_count, + SUM(error_count) AS error_count, + SUM(token_consumed) AS token_consumed, + SUM(switch_count) AS switch_count + FROM ( + SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count + FROM usage_buckets + UNION ALL + SELECT bucket, 0, error_count, 0, 0 + FROM error_buckets + UNION ALL + SELECT bucket, 0, 0, 0, switch_count + FROM switch_buckets + ) t + GROUP BY bucket ) SELECT bucket, (success_count + error_count) AS request_count, - token_consumed + token_consumed, + switch_count FROM combined ORDER BY bucket ASC` @@ -84,13 +110,18 @@ ORDER BY bucket ASC` var bucket time.Time var requests int64 var tokens sql.NullInt64 - if err := rows.Scan(&bucket, &requests, &tokens); err != nil { + var switches sql.NullInt64 + if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil { return nil, err } tokenConsumed := int64(0) if tokens.Valid { tokenConsumed = tokens.Int64 } + switchCount := int64(0) + if switches.Valid { + switchCount = switches.Int64 + } denom := float64(bucketSeconds) if denom <= 0 { @@ -103,6 +134,7 @@ ORDER BY bucket ASC` BucketStart: bucket.UTC(), RequestCount: requests, TokenConsumed: tokenConsumed, + SwitchCount: switchCount, QPS: qps, TPS: tps, }) @@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points [] BucketStart: cursor, RequestCount: 0, TokenConsumed: 0, + SwitchCount: 0, QPS: 0, TPS: 0, }) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 0419942b..44264e72 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -186,6 +186,7 @@ func TestAPIContracts(t *testing.T) { "image_price_4k": null, "claude_code_only": false, "fallback_group_id": null, + "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -607,7 +608,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingService := service.NewSettingService(settingRepo, cfg) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) - authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil, nil) + authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 705ec6da..c512f235 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -111,9 +111,14 @@ type CreateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -135,9 +140,14 @@ type UpdateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled *bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes *[]string // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -594,6 +604,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn return nil, err } } + fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest + if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 { + fallbackOnInvalidRequest = nil + } + // 校验无效请求兜底分组 + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + + // MCPXMLInject:默认为 true,仅当显式传入 false 时关闭 + mcpXMLInject := true + if input.MCPXMLInject != nil { + mcpXMLInject = *input.MCPXMLInject + } // 如果指定了复制账号的源分组,先获取账号 ID 列表 var accountIDsToCopy []int64 @@ -628,22 +654,25 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn } group := &Group{ - Name: input.Name, - Description: input.Description, - Platform: platform, - RateMultiplier: input.RateMultiplier, - IsExclusive: input.IsExclusive, - Status: StatusActive, - SubscriptionType: subscriptionType, - DailyLimitUSD: dailyLimit, - WeeklyLimitUSD: weeklyLimit, - MonthlyLimitUSD: monthlyLimit, - ImagePrice1K: imagePrice1K, - ImagePrice2K: imagePrice2K, - ImagePrice4K: imagePrice4K, - ClaudeCodeOnly: input.ClaudeCodeOnly, - FallbackGroupID: input.FallbackGroupID, - ModelRouting: input.ModelRouting, + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + ImagePrice1K: imagePrice1K, + ImagePrice2K: imagePrice2K, + ImagePrice4K: imagePrice4K, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, + ModelRouting: input.ModelRouting, + MCPXMLInject: mcpXMLInject, + SupportedModelScopes: input.SupportedModelScopes, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -714,6 +743,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro } } +// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// platform/subscriptionType: 当前分组的有效平台/订阅类型 +// fallbackGroupID: 兜底分组 ID +func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error { + if platform != PlatformAnthropic && platform != PlatformAntigravity { + return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups") + } + if subscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("subscription groups cannot set invalid request fallback") + } + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as invalid request fallback group") + } + + fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } + if fallbackGroup.Platform != PlatformAnthropic { + return fmt.Errorf("fallback group must be anthropic platform") + } + if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("fallback group cannot be subscription type") + } + if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + return fmt.Errorf("fallback group cannot have invalid request fallback configured") + } + return nil +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -780,6 +840,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.FallbackGroupID = nil } } + fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest + if input.FallbackGroupIDOnInvalidRequest != nil { + if *input.FallbackGroupIDOnInvalidRequest > 0 { + fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest + } else { + fallbackOnInvalidRequest = nil + } + } + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest // 模型路由配置 if input.ModelRouting != nil { @@ -788,6 +862,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ModelRoutingEnabled != nil { group.ModelRoutingEnabled = *input.ModelRoutingEnabled } + if input.MCPXMLInject != nil { + group.MCPXMLInject = *input.MCPXMLInject + } + + // 支持的模型系列(仅 antigravity 平台使用) + if input.SupportedModelScopes != nil { + group.SupportedModelScopes = *input.SupportedModelScopes + } if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 1daee89f..d921a086 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { panic("unexpected GetAccountIDsByGroupIDs call") } + +type groupRepoStubForInvalidRequestFallback struct { + groups map[int64]*Group + created *Group + updated *Group +} + +func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error { + s.created = g + return nil +} + +func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error { + s.updated = g + return nil +} + +func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) { + return s.GetByIDLite(ctx, id) +} + +func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) { + if g, ok := s.groups[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error { + panic("unexpected Delete call") +} + +func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) { + panic("unexpected DeleteCascade call") +} + +func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) { + panic("unexpected ListActive call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) { + panic("unexpected ListActiveByPlatform call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) { + panic("unexpected ExistsByName call") +} + +func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { + panic("unexpected GetAccountCount call") +} + +func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) { + panic("unexpected DeleteAccountGroupsByGroupID call") +} + +func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { + panic("unexpected GetAccountIDsByGroupIDs call") +} + +func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error { + panic("unexpected BindAccountsToGroup call") +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformOpenAI, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeSubscription, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) { + tests := []struct { + name string + fallback *Group + wantMessage string + }{ + { + name: "openai_target", + fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard}, + wantMessage: "fallback group must be anthropic platform", + }, + { + name: "antigravity_target", + fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard}, + wantMessage: "fallback group must be anthropic platform", + }, + { + name: "subscription_group", + fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription}, + wantMessage: "fallback group cannot be subscription type", + }, + { + name: "nested_fallback", + fallback: &Group{ + ID: 10, + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(), + }, + wantMessage: "fallback group cannot have invalid request fallback configured", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fallbackID := tc.fallback.ID + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: tc.fallback, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantMessage) + require.Nil(t, repo.created) + }) + } +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{} + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group not found") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAntigravity, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) { + zero := int64(0) + repo := &groupRepoStubForInvalidRequestFallback{} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &zero, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + Platform: PlatformOpenAI, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + SubscriptionType: SubscriptionTypeSubscription, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + clear := int64(0) + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + Platform: PlatformOpenAI, + FallbackGroupIDOnInvalidRequest: &clear, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group cannot be subscription type") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAntigravity, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 9b8156e6..d16c4259 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -13,23 +13,34 @@ import ( "net" "net/http" "os" + "strconv" "strings" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/gin-gonic/gin" "github.com/google/uuid" ) const ( - antigravityStickySessionTTL = time.Hour - antigravityMaxRetries = 3 - antigravityRetryBaseDelay = 1 * time.Second - antigravityRetryMaxDelay = 16 * time.Second + antigravityStickySessionTTL = time.Hour + antigravityDefaultMaxRetries = 3 + antigravityRetryBaseDelay = 1 * time.Second + antigravityRetryMaxDelay = 16 * time.Second ) -const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT" +const ( + antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES" + antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES" + antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE" + antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT" + antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE" + antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT" + antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" + antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" +) // antigravityRetryLoopParams 重试循环的参数 type antigravityRetryLoopParams struct { @@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct { action string body []byte quotaScope AntigravityQuotaScope + maxRetries int c *gin.Context httpUpstream HTTPUpstream settingService *SettingService @@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct { resp *http.Response } +// PromptTooLongError 表示上游明确返回 prompt too long +type PromptTooLongError struct { + StatusCode int + RequestID string + Body []byte +} + +func (e *PromptTooLongError) Error() string { + return fmt.Sprintf("prompt too long: status=%d", e.StatusCode) +} + // antigravityRetryLoop 执行带 URL fallback 的重试循环 func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + baseURLs := antigravity.ForwardBaseURLs() + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(baseURLs) if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs + availableURLs = baseURLs + } + + maxRetries := p.maxRetries + if maxRetries <= 0 { + maxRetries = antigravityDefaultMaxRetries } var resp *http.Response @@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe urlFallbackLoop: for urlIdx, baseURL := range availableURLs { usedBaseURL = baseURL - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + for attempt := 1; attempt <= maxRetries; attempt++ { select { case <-p.ctx.Done(): log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) @@ -109,8 +138,8 @@ urlFallbackLoop: log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) continue urlFallbackLoop } - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) + if attempt < maxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, maxRetries, err) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { log.Printf("%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() @@ -134,7 +163,7 @@ urlFallbackLoop: } // 账户/模型配额限流,重试 3 次(指数退避) - if attempt < antigravityMaxRetries { + if attempt < maxRetries { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ @@ -147,7 +176,7 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, maxRetries, truncateForLog(respBody, 200)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { log.Printf("%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() @@ -171,7 +200,7 @@ urlFallbackLoop: respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() - if attempt < antigravityMaxRetries { + if attempt < maxRetries { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ @@ -184,7 +213,7 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, maxRetries, truncateForLog(respBody, 500)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { log.Printf("%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() @@ -390,6 +419,11 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + // 上游透传账号使用专用测试方法 + if account.Type == AccountTypeUpstream { + return s.testUpstreamConnection(ctx, account, modelID) + } + // 获取 token if s.tokenProvider == nil { return nil, errors.New("antigravity token provider not configured") @@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account return nil, lastErr } +// testUpstreamConnection 测试上游透传账号连接 +func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, errors.New("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // 使用 Claude 模型进行测试 + if modelID == "" { + modelID = "claude-sonnet-4-20250514" + } + + // 构建最小测试请求 + testReq := map[string]any{ + "model": modelID, + "max_tokens": 1, + "messages": []map[string]any{ + {"role": "user", "content": "."}, + }, + } + requestBody, err := json.Marshal(testReq) + if err != nil { + return nil, fmt.Errorf("构建请求失败: %w", err) + } + + // 构建 HTTP 请求 + upstreamURL := baseURL + "/v1/messages" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, upstreamURL) + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("请求失败: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 提取响应文本 + var respData map[string]any + text := "" + if json.Unmarshal(respBody, &respData) == nil { + if content, ok := respData["content"].([]any); ok && len(content) > 0 { + if block, ok := content[0].(map[string]any); ok { + if t, ok := block["text"].(string); ok { + text = t + } + } + } + } + + return &TestConnectionResult{ + Text: text, + MappedModel: modelID, + }, nil +} + // buildGeminiTestRequest 构建 Gemini 格式测试请求 // 使用最小 token 消耗:输入 "." + maxOutputTokens: 1 func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) { @@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex } opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx) opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx) + + if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil { + opts.EnableMCPXML = group.MCPXMLInject + } return opts } @@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool { // Forward 转发 Claude 协议请求(Claude → Gemini 转换) func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + // 上游透传账号直接转发,不走 OAuth token 刷新 + if account.Type == AccountTypeUpstream { + return s.ForwardUpstream(ctx, c, account, body) + } + startTime := time.Now() sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, originalModel := claudeReq.Model mappedModel := s.getMappedModel(account, claudeReq.Model) quotaScope, _ := resolveAntigravityQuotaScope(originalModel) + billingModel := originalModel + if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" { + billingModel = mappedModel + } + afterSwitch := antigravityHasAccountSwitch(ctx) + maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch) // 获取 access_token if s.tokenProvider == nil { @@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, + maxRetries: maxRetries, }) if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") @@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, + maxRetries: maxRetries, }) if retryErr != nil { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { + if resp.StatusCode == http.StatusBadRequest { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500)) + } + if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody + maxBytes := 2048 + if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + } + upstreamDetail := "" + if logBody { + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "prompt_too_long", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &PromptTooLongError{ + StatusCode: resp.StatusCode, + RequestID: resp.Header.Get("x-request-id"), + Body: respBody, + } + } s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) if s.shouldFailoverUpstreamError(resp.StatusCode) { @@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 + Model: billingModel, // 计费模型(可按映射模型覆盖) Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -1006,21 +1171,55 @@ func isSignatureRelatedError(respBody []byte) bool { return false } +func isPromptTooLongError(respBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + if msg == "" { + msg = strings.ToLower(string(respBody)) + } + return strings.Contains(msg, "prompt is too long") +} + func extractAntigravityErrorMessage(body []byte) string { var payload map[string]any if err := json.Unmarshal(body, &payload); err != nil { return "" } + parseNestedMessage := func(msg string) string { + trimmed := strings.TrimSpace(msg) + if trimmed == "" || !strings.HasPrefix(trimmed, "{") { + return "" + } + var nested map[string]any + if err := json.Unmarshal([]byte(trimmed), &nested); err != nil { + return "" + } + if errObj, ok := nested["error"].(map[string]any); ok { + if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + return "" + } + // Google-style: {"error": {"message": "..."}} if errObj, ok := payload["error"].(map[string]any); ok { if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { + if innerMsg := parseNestedMessage(msg); innerMsg != "" { + return innerMsg + } return msg } } // Fallback: top-level message if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" { + if innerMsg := parseNestedMessage(msg); innerMsg != "" { + return innerMsg + } return msg } @@ -1248,6 +1447,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque return changed, nil } +// ForwardUpstream 透传请求到上游 Antigravity 服务 +// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + // 获取上游配置 + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, fmt.Errorf("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // 解析请求获取模型信息 + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, fmt.Errorf("parse claude request: %w", err) + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, fmt.Errorf("missing model") + } + originalModel := claudeReq.Model + billingModel := originalModel + + // 构建上游请求 URL + upstreamURL := baseURL + "/v1/messages" + + // 创建请求 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create upstream request: %w", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) // Claude API 兼容 + + // 透传 Claude 相关 headers + if v := c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + log.Printf("%s upstream request failed: %v", prefix, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // 处理错误响应 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // 429 错误时标记账号限流 + if resp.StatusCode == http.StatusTooManyRequests { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude) + } + + // 透传上游错误 + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(resp.StatusCode) + _, _ = c.Writer.Write(respBody) + + return &ForwardResult{ + Model: billingModel, + }, nil + } + + // 处理成功响应(流式/非流式) + var usage *ClaudeUsage + var firstTokenMs *int + + if claudeReq.Stream { + // 流式响应:透传 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime) + } else { + // 非流式响应:直接透传 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read upstream response: %w", err) + } + + // 提取 usage + usage = s.extractClaudeUsage(respBody) + + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(http.StatusOK) + _, _ = c.Writer.Write(respBody) + } + + // 构建计费结果 + duration := time.Since(startTime) + log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + + return &ForwardResult{ + Model: billingModel, + Stream: claudeReq.Stream, + Duration: duration, + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + }, + }, nil +} + +// streamUpstreamResponse 透传上游流式响应并提取 usage +func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) { + usage := &ClaudeUsage{} + var firstTokenMs *int + var firstTokenRecorded bool + + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 1024*1024) + + for scanner.Scan() { + line := scanner.Bytes() + + // 记录首 token 时间 + if !firstTokenRecorded && len(line) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + firstTokenRecorded = true + } + + // 尝试从 message_delta 或 message_stop 事件提取 usage + if bytes.HasPrefix(line, []byte("data: ")) { + dataStr := bytes.TrimPrefix(line, []byte("data: ")) + var event map[string]any + if json.Unmarshal(dataStr, &event) == nil { + if u, ok := event["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheCreationInputTokens = int(v) + } + } + } + } + + // 透传行 + _, _ = c.Writer.Write(line) + _, _ = c.Writer.Write([]byte("\n")) + c.Writer.Flush() + } + + return usage, firstTokenMs +} + +// extractClaudeUsage 从非流式 Claude 响应提取 usage +func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + var resp map[string]any + if json.Unmarshal(body, &resp) != nil { + return usage + } + if u, ok := resp["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok { + usage.CacheCreationInputTokens = int(v) + } + } + return usage +} + // ForwardGemini 转发 Gemini 协议请求 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { startTime := time.Now() @@ -1287,6 +1688,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } mappedModel := s.getMappedModel(account, originalModel) + billingModel := originalModel + if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" { + billingModel = mappedModel + } + afterSwitch := antigravityHasAccountSwitch(ctx) + maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch) // 获取 access_token if s.tokenProvider == nil { @@ -1306,8 +1713,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co proxyURL = account.Proxy.URL() } + // 过滤掉 parts 为空的消息(Gemini API 不接受空 parts) + filteredBody, err := filterEmptyPartsFromGeminiRequest(body) + if err != nil { + log.Printf("[Antigravity] Failed to filter empty parts: %v", err) + filteredBody = body + } + // Antigravity 上游要求必须包含身份提示词,注入到请求中 - injectedBody, err := injectIdentityPatchToGeminiRequest(body) + injectedBody, err := injectIdentityPatchToGeminiRequest(filteredBody) if err != nil { return nil, err } @@ -1344,6 +1758,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, + maxRetries: maxRetries, }) if err != nil { return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") @@ -1493,7 +1908,7 @@ handleSuccess: return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, + Model: billingModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -1544,6 +1959,81 @@ func antigravityUseScopeRateLimit() bool { return true } +func antigravityHasAccountSwitch(ctx context.Context) bool { + if ctx == nil { + return false + } + if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok { + return v > 0 + } + return false +} + +func antigravityMaxRetries() int { + raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv)) + if raw == "" { + return antigravityDefaultMaxRetries + } + value, err := strconv.Atoi(raw) + if err != nil || value <= 0 { + return antigravityDefaultMaxRetries + } + return value +} + +func antigravityMaxRetriesAfterSwitch() int { + raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv)) + if raw == "" { + return antigravityMaxRetries() + } + value, err := strconv.Atoi(raw) + if err != nil || value <= 0 { + return antigravityMaxRetries() + } + return value +} + +// antigravityMaxRetriesForModel 根据模型类型获取重试次数 +// 优先使用模型细分配置,未设置则回退到平台级配置 +func antigravityMaxRetriesForModel(model string, afterSwitch bool) int { + var envKey string + if strings.HasPrefix(model, "claude-") { + envKey = antigravityMaxRetriesClaudeEnv + } else if isImageGenerationModel(model) { + envKey = antigravityMaxRetriesGeminiImageEnv + } else if strings.HasPrefix(model, "gemini-") { + envKey = antigravityMaxRetriesGeminiTextEnv + } + + if envKey != "" { + if raw := strings.TrimSpace(os.Getenv(envKey)); raw != "" { + if value, err := strconv.Atoi(raw); err == nil && value > 0 { + return value + } + } + } + if afterSwitch { + return antigravityMaxRetriesAfterSwitch() + } + return antigravityMaxRetries() +} + +func antigravityUseMappedModelForBilling() bool { + v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityBillingModelEnv))) + return v == "1" || v == "true" || v == "yes" || v == "on" +} + +func antigravityFallbackCooldownSeconds() (time.Duration, bool) { + raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv)) + if raw == "" { + return 0, false + } + seconds, err := strconv.Atoi(raw) + if err != nil || seconds <= 0 { + return 0, false + } + return time.Duration(seconds) * time.Second, true +} func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { // 429 使用 Gemini 格式解析(从 body 解析重置时间) if statusCode == 429 { @@ -1556,6 +2046,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes } defaultDur := time.Duration(fallbackMinutes) * time.Minute + if fallbackDur, ok := antigravityFallbackCooldownSeconds(); ok { + defaultDur = fallbackDur + } ra := time.Now().Add(defaultDur) if useScopeLimit { log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) @@ -2193,6 +2686,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg) } +func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error { + return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body) +} + func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error { statusStr := "UNKNOWN" switch status { @@ -2618,3 +3115,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) { return json.Marshal(payload) } + +// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息 +// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误 +func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return nil, err + } + + contents, ok := payload["contents"].([]any) + if !ok || len(contents) == 0 { + return body, nil + } + + filtered := make([]any, 0, len(contents)) + modified := false + + for _, c := range contents { + contentMap, ok := c.(map[string]any) + if !ok { + filtered = append(filtered, c) + continue + } + + parts, hasParts := contentMap["parts"] + if !hasParts { + filtered = append(filtered, c) + continue + } + + partsSlice, ok := parts.([]any) + if !ok { + filtered = append(filtered, c) + continue + } + + // 跳过 parts 为空数组的消息 + if len(partsSlice) == 0 { + modified = true + continue + } + + filtered = append(filtered, c) + } + + if !modified { + return body, nil + } + + payload["contents"] = filtered + return json.Marshal(payload) +} diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 05ad9bbd..32a591ef 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -1,10 +1,16 @@ package service import ( + "bytes" + "context" "encoding/json" + "io" + "net/http" + "net/http/httptest" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) { require.Equal(t, "secret plan", blocks[0]["text"]) require.Equal(t, "tool_use", blocks[1]["type"]) } + +func TestIsPromptTooLongError(t *testing.T) { + require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`))) + require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`))) + require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`))) +} + +type httpUpstreamStub struct { + resp *http.Response + err error +} + +func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return s.resp, s.err +} + +func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + return s.resp, s.err +} + +func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-opus-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "max_tokens": 1, + "stream": false, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + respBody := []byte(`{"error":{"message":"Prompt is too long"}}`) + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"X-Request-Id": []string{"req-1"}}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + account := &Account{ + ID: 1, + Name: "acc-1", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + } + + result, err := svc.Forward(context.Background(), c, account, body) + require.Nil(t, result) + + var promptErr *PromptTooLongError + require.ErrorAs(t, err, &promptErr) + require.Equal(t, http.StatusBadRequest, promptErr.StatusCode) + require.Equal(t, "req-1", promptErr.RequestID) + require.NotEmpty(t, promptErr.Body) + + raw, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := raw.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, "prompt_too_long", events[0].Kind) +} + +func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) { + t.Setenv(antigravityMaxRetriesEnv, "4") + t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7") + t.Setenv(antigravityMaxRetriesClaudeEnv, "") + t.Setenv(antigravityMaxRetriesGeminiTextEnv, "") + t.Setenv(antigravityMaxRetriesGeminiImageEnv, "") + + got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false) + require.Equal(t, 4, got) + + got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true) + require.Equal(t, 7, got) +} + +func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) { + t.Setenv(antigravityMaxRetriesEnv, "5") + t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "") + t.Setenv(antigravityMaxRetriesClaudeEnv, "") + t.Setenv(antigravityMaxRetriesGeminiTextEnv, "") + t.Setenv(antigravityMaxRetriesGeminiImageEnv, "") + + got := antigravityMaxRetriesForModel("gemini-2.5-flash", true) + require.Equal(t, 5, got) +} diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go index 34cd9a4c..e1a0a1f2 100644 --- a/backend/internal/service/antigravity_quota_scope.go +++ b/backend/internal/service/antigravity_quota_scope.go @@ -1,6 +1,7 @@ package service import ( + "slices" "strings" "time" ) @@ -16,6 +17,21 @@ const ( AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image" ) +// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中 +func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool { + if len(supportedScopes) == 0 { + // 未配置时默认全部支持 + return true + } + supported := slices.Contains(supportedScopes, string(scope)) + return supported +} + +// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本) +func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { + return resolveAntigravityQuotaScope(requestedModel) +} + // resolveAntigravityQuotaScope 根据模型名称解析配额域 func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { model := normalizeAntigravityModelName(requestedModel) diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 6cbeb98a..d15b5817 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -32,25 +32,30 @@ type APIKeyAuthUserSnapshot struct { // APIKeyAuthGroupSnapshot 分组快照 type APIKeyAuthGroupSnapshot struct { - ID int64 `json:"id"` - Name string `json:"name"` - Platform string `json:"platform"` - Status string `json:"status"` - SubscriptionType string `json:"subscription_type"` - RateMultiplier float64 `json:"rate_multiplier"` - DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` - WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` - MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` - ImagePrice1K *float64 `json:"image_price_1k,omitempty"` - ImagePrice2K *float64 `json:"image_price_2k,omitempty"` - ImagePrice4K *float64 `json:"image_price_4k,omitempty"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Only anthropic groups use these fields; others may leave them empty. ModelRouting map[string][]int64 `json:"model_routing,omitempty"` ModelRoutingEnabled bool `json:"model_routing_enabled"` + MCPXMLInject bool `json:"mcp_xml_inject"` + + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 979ff77d..f5bba7d0 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -226,22 +226,25 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ - ID: apiKey.Group.ID, - Name: apiKey.Group.Name, - Platform: apiKey.Group.Platform, - Status: apiKey.Group.Status, - SubscriptionType: apiKey.Group.SubscriptionType, - RateMultiplier: apiKey.Group.RateMultiplier, - DailyLimitUSD: apiKey.Group.DailyLimitUSD, - WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, - MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, - ImagePrice1K: apiKey.Group.ImagePrice1K, - ImagePrice2K: apiKey.Group.ImagePrice2K, - ImagePrice4K: apiKey.Group.ImagePrice4K, - ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, - FallbackGroupID: apiKey.Group.FallbackGroupID, - ModelRouting: apiKey.Group.ModelRouting, - ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, + ModelRouting: apiKey.Group.ModelRouting, + ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + MCPXMLInject: apiKey.Group.MCPXMLInject, + SupportedModelScopes: apiKey.Group.SupportedModelScopes, } } return snapshot @@ -272,23 +275,26 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho } if snapshot.Group != nil { apiKey.Group = &Group{ - ID: snapshot.Group.ID, - Name: snapshot.Group.Name, - Platform: snapshot.Group.Platform, - Status: snapshot.Group.Status, - Hydrated: true, - SubscriptionType: snapshot.Group.SubscriptionType, - RateMultiplier: snapshot.Group.RateMultiplier, - DailyLimitUSD: snapshot.Group.DailyLimitUSD, - WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, - MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, - ImagePrice1K: snapshot.Group.ImagePrice1K, - ImagePrice2K: snapshot.Group.ImagePrice2K, - ImagePrice4K: snapshot.Group.ImagePrice4K, - ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, - FallbackGroupID: snapshot.Group.FallbackGroupID, - ModelRouting: snapshot.Group.ModelRouting, - ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, + ModelRouting: snapshot.Group.ModelRouting, + ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + MCPXMLInject: snapshot.Group.MCPXMLInject, + SupportedModelScopes: snapshot.Group.SupportedModelScopes, } } return apiKey diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index c824ec1e..25604d2c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) } } - // 应用优惠码(如果提供且功能已启用) if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 2db72825..0295c23b 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -31,6 +31,7 @@ const ( AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 + AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) ) // Redeem type constants diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b5827a9e..9b31a9c6 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -70,6 +70,15 @@ func shortSessionHash(sessionHash string) string { return sessionHash[:8] } +func normalizeClaudeModelForAnthropic(requestedModel string) string { + for _, prefix := range anthropicPrefixMappings { + if strings.HasPrefix(requestedModel, prefix) { + return prefix + } + } + return requestedModel +} + func redactAuthHeaderValue(v string) string { v = strings.TrimSpace(v) if v == "" { @@ -252,11 +261,20 @@ var ( "You are a file search specialist for Claude Code", // Explore Agent 版 "You are a helpful AI assistant tasked with summarizing conversations", // Compact 版 } + + anthropicPrefixMappings = []string{ + "claude-opus-4-5", + "claude-haiku-4-5", + "claude-sonnet-4-5", + } ) // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") +// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内 +var ErrModelScopeNotSupported = errors.New("model scope not supported by this group") + // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -1135,6 +1153,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } + // Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查) + if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { + if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { + return nil, err + } + } + accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, err @@ -1632,6 +1657,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (* return group, nil } +func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + return s.resolveGroupByID(ctx, groupID) +} + func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { return nil @@ -1697,7 +1726,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID } // 强制平台模式不检查 Claude Code 限制 - if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { + if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" { return nil, groupID, nil } @@ -2026,6 +2055,13 @@ func shuffleWithinPriority(accounts []*Account) { // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { + // 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内 + if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { + if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { + return nil, err + } + } + preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) @@ -2461,6 +2497,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo // Antigravity 平台使用专门的模型支持检查 return IsAntigravityModelSupported(requestedModel) } + if account.Platform == PlatformAnthropic { + requestedModel = normalizeClaudeModelForAnthropic(requestedModel) + } // Gemini API Key 账户直接透传,由上游判断模型是否支持 if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey { return true @@ -2910,16 +2949,28 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) - // 应用模型映射(仅对apikey类型账号) + // 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射) + mappedModel := reqModel + mappingSource := "" if account.Type == AccountTypeAPIKey { - mappedModel := account.GetMappedModel(reqModel) + mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { - // 替换请求体中的模型名 - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) + mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic { + normalized := normalizeClaudeModelForAnthropic(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + // 替换请求体中的模型名 + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) + } // 获取凭证 token, tokenType, err := s.GetAccessToken(ctx, account) @@ -4842,16 +4893,28 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return nil } - // 应用模型映射(仅对 apikey 类型账号) - if account.Type == AccountTypeAPIKey { - if reqModel != "" { - mappedModel := account.GetMappedModel(reqModel) + // 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射) + if reqModel != "" { + mappedModel := reqModel + mappingSource := "" + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic { + normalized := normalizeClaudeModelForAnthropic(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + } } // 获取凭证 @@ -5103,6 +5166,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { return normalized, nil } +// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内 +func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error { + scope, ok := ResolveAntigravityQuotaScope(requestedModel) + if !ok { + return nil // 无法解析 scope,跳过检查 + } + + group, err := s.resolveGroupByID(ctx, groupID) + if err != nil { + return nil // 查询失败时放行 + } + if group == nil { + return nil // 分组不存在时放行 + } + + if !IsScopeSupported(group.SupportedModelScopes, scope) { + return ErrModelScopeNotSupported + } + return nil +} + // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 2d2e86d5..bd322991 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } + // 过滤掉 parts 为空的消息(Gemini API 不接受空 parts) + if filteredBody, err := filterEmptyPartsFromGeminiRequest(body); err == nil { + body = filteredBody + } + switch action { case "generateContent", "streamGenerateContent", "countTokens": // ok diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index d6d1269b..1302047a 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -29,6 +29,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 + // 无效请求兜底分组(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置 // key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*") @@ -36,6 +38,13 @@ type Group struct { ModelRouting map[string][]int64 ModelRoutingEnabled bool + // MCP XML 协议注入开关(仅 antigravity 平台使用) + MCPXMLInject bool + + // 支持的模型系列(仅 antigravity 平台使用) + // 可选值: claude, gemini_text, gemini_image + SupportedModelScopes []string + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/ops_metrics_collector.go b/backend/internal/service/ops_metrics_collector.go index edf32cf2..30adaae0 100644 --- a/backend/internal/service/ops_metrics_collector.go +++ b/backend/internal/service/ops_metrics_collector.go @@ -285,6 +285,11 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error { return fmt.Errorf("query error counts: %w", err) } + accountSwitchCount, err := c.queryAccountSwitchCount(ctx, windowStart, windowEnd) + if err != nil { + return fmt.Errorf("query account switch counts: %w", err) + } + windowSeconds := windowEnd.Sub(windowStart).Seconds() if windowSeconds <= 0 { windowSeconds = 60 @@ -309,9 +314,10 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error { Upstream429Count: upstream429, Upstream529Count: upstream529, - TokenConsumed: tokenConsumed, - QPS: float64Ptr(roundTo1DP(qps)), - TPS: float64Ptr(roundTo1DP(tps)), + TokenConsumed: tokenConsumed, + AccountSwitchCount: accountSwitchCount, + QPS: float64Ptr(roundTo1DP(qps)), + TPS: float64Ptr(roundTo1DP(tps)), DurationP50Ms: duration.p50, DurationP90Ms: duration.p90, @@ -551,6 +557,27 @@ WHERE created_at >= $1 AND created_at < $2` return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil } +func (c *OpsMetricsCollector) queryAccountSwitchCount(ctx context.Context, start, end time.Time) (int64, error) { + q := ` +SELECT + COALESCE(SUM(CASE + WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1 + ELSE 0 + END), 0) AS switch_count +FROM ops_error_logs o +CROSS JOIN LATERAL jsonb_array_elements( + COALESCE(NULLIF(o.upstream_errors, 'null'::jsonb), '[]'::jsonb) +) AS ev +WHERE o.created_at >= $1 AND o.created_at < $2 + AND o.is_count_tokens = FALSE` + + var count int64 + if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&count); err != nil { + return 0, err + } + return count, nil +} + type opsCollectedSystemStats struct { cpuUsagePercent *float64 memoryUsedMB *int64 diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 515b47bb..347b06b5 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -161,7 +161,8 @@ type OpsInsertSystemMetricsInput struct { Upstream429Count int64 Upstream529Count int64 - TokenConsumed int64 + TokenConsumed int64 + AccountSwitchCount int64 QPS *float64 TPS *float64 @@ -223,8 +224,9 @@ type OpsSystemMetricsSnapshot struct { DBConnIdle *int `json:"db_conn_idle"` DBConnWaiting *int `json:"db_conn_waiting"` - GoroutineCount *int `json:"goroutine_count"` - ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"` + GoroutineCount *int `json:"goroutine_count"` + ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"` + AccountSwitchCount *int64 `json:"account_switch_count"` } type OpsUpsertJobHeartbeatInput struct { diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index 8d98e43f..ffe4c934 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/lib/pq" @@ -476,9 +477,13 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq continue } + attemptCtx := ctx + if switches > 0 { + attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches) + } exec := func() *opsRetryExecution { defer selection.ReleaseFunc() - return s.executeWithAccount(ctx, reqType, errorLog, body, account) + return s.executeWithAccount(attemptCtx, reqType, errorLog, body, account) }() if exec != nil { diff --git a/backend/internal/service/ops_trend_models.go b/backend/internal/service/ops_trend_models.go index f6d07c14..97bbfebe 100644 --- a/backend/internal/service/ops_trend_models.go +++ b/backend/internal/service/ops_trend_models.go @@ -6,6 +6,7 @@ type OpsThroughputTrendPoint struct { BucketStart time.Time `json:"bucket_start"` RequestCount int64 `json:"request_count"` TokenConsumed int64 `json:"token_consumed"` + SwitchCount int64 `json:"switch_count"` QPS float64 `json:"qps"` TPS float64 `json:"tps"` } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 99bf7fd0..1bfb392e 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -39,7 +39,7 @@ type UserRepository interface { ExistsByEmail(ctx context.Context, email string) (bool, error) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) - // TOTP 相关方法 + // TOTP 双因素认证 UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error EnableTotp(ctx context.Context, userID int64) error DisableTotp(ctx context.Context, userID int64) error diff --git a/backend/migrations/042b_add_ops_system_metrics_switch_count.sql b/backend/migrations/042b_add_ops_system_metrics_switch_count.sql new file mode 100644 index 00000000..6d9f48e5 --- /dev/null +++ b/backend/migrations/042b_add_ops_system_metrics_switch_count.sql @@ -0,0 +1,3 @@ +-- ops_system_metrics 增加账号切换次数统计(按分钟窗口) +ALTER TABLE ops_system_metrics + ADD COLUMN IF NOT EXISTS account_switch_count BIGINT NOT NULL DEFAULT 0; diff --git a/backend/migrations/043b_add_group_invalid_request_fallback.sql b/backend/migrations/043b_add_group_invalid_request_fallback.sql new file mode 100644 index 00000000..1c792704 --- /dev/null +++ b/backend/migrations/043b_add_group_invalid_request_fallback.sql @@ -0,0 +1,13 @@ +-- 043_add_group_invalid_request_fallback.sql +-- 添加无效请求兜底分组配置 + +-- 添加 fallback_group_id_on_invalid_request 字段:无效请求兜底使用的分组 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS fallback_group_id_on_invalid_request BIGINT REFERENCES groups(id) ON DELETE SET NULL; + +-- 添加索引优化查询 +CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id_on_invalid_request +ON groups(fallback_group_id_on_invalid_request) WHERE deleted_at IS NULL AND fallback_group_id_on_invalid_request IS NOT NULL; + +-- 添加字段注释 +COMMENT ON COLUMN groups.fallback_group_id_on_invalid_request IS '无效请求兜底使用的分组 ID'; diff --git a/backend/migrations/044b_add_group_mcp_xml_inject.sql b/backend/migrations/044b_add_group_mcp_xml_inject.sql new file mode 100644 index 00000000..7db71dd8 --- /dev/null +++ b/backend/migrations/044b_add_group_mcp_xml_inject.sql @@ -0,0 +1,2 @@ +-- Add mcp_xml_inject field to groups table (for antigravity platform) +ALTER TABLE groups ADD COLUMN mcp_xml_inject BOOLEAN NOT NULL DEFAULT true; diff --git a/backend/migrations/046b_add_group_supported_model_scopes.sql b/backend/migrations/046b_add_group_supported_model_scopes.sql new file mode 100644 index 00000000..0b2b3968 --- /dev/null +++ b/backend/migrations/046b_add_group_supported_model_scopes.sql @@ -0,0 +1,6 @@ +-- 添加分组支持的模型系列字段 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS supported_model_scopes JSONB NOT NULL +DEFAULT '["claude", "gemini_text", "gemini_image"]'::jsonb; + +COMMENT ON COLUMN groups.supported_model_scopes IS '支持的模型系列:claude, gemini_text, gemini_image'; diff --git a/docs/rename_local_migrations_20260202.sql b/docs/rename_local_migrations_20260202.sql new file mode 100644 index 00000000..911ed17d --- /dev/null +++ b/docs/rename_local_migrations_20260202.sql @@ -0,0 +1,34 @@ +-- 修正 schema_migrations 中“本地改名”的迁移文件名 +-- 适用场景:你已执行过旧文件名的迁移,合并后仅改了自己这边的文件名 + +BEGIN; + +UPDATE schema_migrations +SET filename = '042b_add_ops_system_metrics_switch_count.sql' +WHERE filename = '042_add_ops_system_metrics_switch_count.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '042b_add_ops_system_metrics_switch_count.sql' + ); + +UPDATE schema_migrations +SET filename = '043b_add_group_invalid_request_fallback.sql' +WHERE filename = '043_add_group_invalid_request_fallback.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '043b_add_group_invalid_request_fallback.sql' + ); + +UPDATE schema_migrations +SET filename = '044b_add_group_mcp_xml_inject.sql' +WHERE filename = '044_add_group_mcp_xml_inject.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '044b_add_group_mcp_xml_inject.sql' + ); + +UPDATE schema_migrations +SET filename = '046b_add_group_supported_model_scopes.sql' +WHERE filename = '046_add_group_supported_model_scopes.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '046b_add_group_supported_model_scopes.sql' + ); + +COMMIT; diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index bf2c246c..a1c41e8c 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -136,6 +136,7 @@ export interface OpsThroughputTrendPoint { bucket_start: string request_count: number token_consumed: number + switch_count?: number qps: number tps: number } @@ -284,6 +285,7 @@ export interface OpsSystemMetricsSnapshot { goroutine_count?: number | null concurrency_queue_depth?: number | null + account_switch_count?: number | null } export interface OpsJobHeartbeat { diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 8e525fa3..8dcddff7 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -56,7 +56,6 @@ > -
- +
-
-
+ + + +
+
+ + +
+
+ + +

{{ t('admin.accounts.upstream.baseUrlHint') }}

+
+
+ + +

{{ t('admin.accounts.upstream.apiKeyHint') }}

@@ -1953,6 +2019,9 @@ const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(true) const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling +const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream +const upstreamBaseUrl = ref('') // For upstream type: base URL +const upstreamApiKey = ref('') // For upstream type: API key const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one') @@ -2055,7 +2124,13 @@ const form = reactive({ }) // Helper to check if current type needs OAuth flow -const isOAuthFlow = computed(() => accountCategory.value === 'oauth-based') +const isOAuthFlow = computed(() => { + // Antigravity upstream 类型不需要 OAuth 流程 + if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { + return false + } + return accountCategory.value === 'oauth-based' +}) const isManualInputMethod = computed(() => { return oauthFlowRef.value?.inputMethod === 'manual' @@ -2095,10 +2170,15 @@ watch( } ) -// Sync form.type based on accountCategory and addMethod +// Sync form.type based on accountCategory, addMethod, and antigravityAccountType watch( - [accountCategory, addMethod], - ([category, method]) => { + [accountCategory, addMethod, antigravityAccountType], + ([category, method, agType]) => { + // Antigravity upstream 类型 + if (form.platform === 'antigravity' && agType === 'upstream') { + form.type = 'upstream' + return + } if (category === 'oauth-based') { form.type = method as AccountType // 'oauth' or 'setup-token' } else { @@ -2126,9 +2206,10 @@ watch( if (newPlatform !== 'anthropic') { interceptWarmupRequests.value = false } - // Antigravity only supports OAuth + // Antigravity: reset to OAuth by default, but allow upstream selection if (newPlatform === 'antigravity') { accountCategory.value = 'oauth-based' + antigravityAccountType.value = 'oauth' } // Reset OAuth states oauth.resetState() @@ -2361,6 +2442,9 @@ const resetForm = () => { sessionIdleTimeout.value = null tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false + antigravityAccountType.value = 'oauth' + upstreamBaseUrl.value = '' + upstreamApiKey.value = '' tempUnschedEnabled.value = false tempUnschedRules.value = [] geminiOAuthType.value = 'code_assist' @@ -2442,6 +2526,36 @@ const handleSubmit = async () => { return } + // For Antigravity upstream type, create directly + if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + if (!upstreamBaseUrl.value.trim()) { + appStore.showError(t('admin.accounts.upstream.pleaseEnterBaseUrl')) + return + } + if (!upstreamApiKey.value.trim()) { + appStore.showError(t('admin.accounts.upstream.pleaseEnterApiKey')) + return + } + + submitting.value = true + try { + const credentials: Record = { + base_url: upstreamBaseUrl.value.trim(), + api_key: upstreamApiKey.value.trim() + } + await createAccountAndFinish(form.platform, 'upstream', credentials) + } catch (error: any) { + appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) + } finally { + submitting.value = false + } + return + } + // For apikey type, create directly if (!apiKeyValue.value.trim()) { appStore.showError(t('admin.accounts.pleaseEnterApiKey')) diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 5edbd3b6..fbb1942a 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -238,14 +238,14 @@ - + + diff --git a/frontend/src/views/setup/SetupWizardView.vue b/frontend/src/views/setup/SetupWizardView.vue index f3c773ca..fcf5aa72 100644 --- a/frontend/src/views/setup/SetupWizardView.vue +++ b/frontend/src/views/setup/SetupWizardView.vue @@ -91,6 +91,18 @@
+
+
+

+ {{ t("setup.redis.enableTls") }} +

+

+ {{ t("setup.redis.enableTlsHint") }} +

+
+ +
+