Compare commits

..

1 Commits

Author SHA1 Message Date
shaw
3bddbb6afe chore: update version 2026-02-09 09:42:29 +08:00
121 changed files with 463 additions and 5473 deletions

View File

@@ -32,7 +32,7 @@ jobs:
working-directory: backend
run: |
go install github.com/securego/gosec/v2/cmd/gosec@latest
gosec -conf .gosec.json -severity high -confidence high ./...
gosec -severity high -confidence high ./...
frontend-security:
runs-on: ubuntu-latest

View File

@@ -1,5 +0,0 @@
{
"global": {
"exclude": "G704"
}
}

View File

@@ -1 +1 @@
0.1.83
0.1.76

View File

@@ -44,8 +44,6 @@ type ErrorPassthroughRule struct {
PassthroughBody bool `json:"passthrough_body,omitempty"`
// CustomMessage holds the value of the "custom_message" field.
CustomMessage *string `json:"custom_message,omitempty"`
// SkipMonitoring holds the value of the "skip_monitoring" field.
SkipMonitoring bool `json:"skip_monitoring,omitempty"`
// Description holds the value of the "description" field.
Description *string `json:"description,omitempty"`
selectValues sql.SelectValues
@@ -58,7 +56,7 @@ func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms:
values[i] = new([]byte)
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody, errorpassthroughrule.FieldSkipMonitoring:
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody:
values[i] = new(sql.NullBool)
case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode:
values[i] = new(sql.NullInt64)
@@ -173,12 +171,6 @@ func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) err
_m.CustomMessage = new(string)
*_m.CustomMessage = value.String
}
case errorpassthroughrule.FieldSkipMonitoring:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field skip_monitoring", values[i])
} else if value.Valid {
_m.SkipMonitoring = value.Bool
}
case errorpassthroughrule.FieldDescription:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field description", values[i])
@@ -265,9 +257,6 @@ func (_m *ErrorPassthroughRule) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("skip_monitoring=")
builder.WriteString(fmt.Sprintf("%v", _m.SkipMonitoring))
builder.WriteString(", ")
if v := _m.Description; v != nil {
builder.WriteString("description=")
builder.WriteString(*v)

View File

@@ -39,8 +39,6 @@ const (
FieldPassthroughBody = "passthrough_body"
// FieldCustomMessage holds the string denoting the custom_message field in the database.
FieldCustomMessage = "custom_message"
// FieldSkipMonitoring holds the string denoting the skip_monitoring field in the database.
FieldSkipMonitoring = "skip_monitoring"
// FieldDescription holds the string denoting the description field in the database.
FieldDescription = "description"
// Table holds the table name of the errorpassthroughrule in the database.
@@ -63,7 +61,6 @@ var Columns = []string{
FieldResponseCode,
FieldPassthroughBody,
FieldCustomMessage,
FieldSkipMonitoring,
FieldDescription,
}
@@ -98,8 +95,6 @@ var (
DefaultPassthroughCode bool
// DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field.
DefaultPassthroughBody bool
// DefaultSkipMonitoring holds the default value on creation for the "skip_monitoring" field.
DefaultSkipMonitoring bool
)
// OrderOption defines the ordering options for the ErrorPassthroughRule queries.
@@ -160,11 +155,6 @@ func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCustomMessage, opts...).ToFunc()
}
// BySkipMonitoring orders the results by the skip_monitoring field.
func BySkipMonitoring(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSkipMonitoring, opts...).ToFunc()
}
// ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDescription, opts...).ToFunc()

View File

@@ -104,11 +104,6 @@ func CustomMessage(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
}
// SkipMonitoring applies equality check predicate on the "skip_monitoring" field. It's identical to SkipMonitoringEQ.
func SkipMonitoring(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
}
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
@@ -549,16 +544,6 @@ func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v))
}
// SkipMonitoringEQ applies the EQ predicate on the "skip_monitoring" field.
func SkipMonitoringEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
}
// SkipMonitoringNEQ applies the NEQ predicate on the "skip_monitoring" field.
func SkipMonitoringNEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldSkipMonitoring, v))
}
// DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))

View File

@@ -172,20 +172,6 @@ func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *Error
return _c
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_c *ErrorPassthroughRuleCreate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleCreate {
_c.mutation.SetSkipMonitoring(v)
return _c
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_c *ErrorPassthroughRuleCreate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleCreate {
if v != nil {
_c.SetSkipMonitoring(*v)
}
return _c
}
// SetDescription sets the "description" field.
func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate {
_c.mutation.SetDescription(v)
@@ -263,10 +249,6 @@ func (_c *ErrorPassthroughRuleCreate) defaults() {
v := errorpassthroughrule.DefaultPassthroughBody
_c.mutation.SetPassthroughBody(v)
}
if _, ok := _c.mutation.SkipMonitoring(); !ok {
v := errorpassthroughrule.DefaultSkipMonitoring
_c.mutation.SetSkipMonitoring(v)
}
}
// check runs all checks and user-defined validators on the builder.
@@ -305,9 +287,6 @@ func (_c *ErrorPassthroughRuleCreate) check() error {
if _, ok := _c.mutation.PassthroughBody(); !ok {
return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)}
}
if _, ok := _c.mutation.SkipMonitoring(); !ok {
return &ValidationError{Name: "skip_monitoring", err: errors.New(`ent: missing required field "ErrorPassthroughRule.skip_monitoring"`)}
}
return nil
}
@@ -387,10 +366,6 @@ func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlg
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
_node.CustomMessage = &value
}
if value, ok := _c.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
_node.SkipMonitoring = value
}
if value, ok := _c.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
_node.Description = &value
@@ -633,18 +608,6 @@ func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleU
return u
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsert) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsert {
u.Set(errorpassthroughrule.FieldSkipMonitoring, v)
return u
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsert) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsert {
u.SetExcluded(errorpassthroughrule.FieldSkipMonitoring)
return u
}
// SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert {
u.Set(errorpassthroughrule.FieldDescription, v)
@@ -925,20 +888,6 @@ func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRu
})
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsertOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.SetSkipMonitoring(v)
})
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsertOne) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.UpdateSkipMonitoring()
})
}
// SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
@@ -1388,20 +1337,6 @@ func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughR
})
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsertBulk) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.SetSkipMonitoring(v)
})
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsertBulk) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.UpdateSkipMonitoring()
})
}
// SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {

View File

@@ -227,20 +227,6 @@ func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRule
return _u
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_u *ErrorPassthroughRuleUpdate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdate {
_u.mutation.SetSkipMonitoring(v)
return _u
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetSkipMonitoring(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetDescription(v)
@@ -401,9 +387,6 @@ func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, e
if _u.mutation.CustomMessageCleared() {
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
}
if value, ok := _u.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
}
@@ -628,20 +611,6 @@ func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughR
return _u
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetSkipMonitoring(v)
return _u
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetSkipMonitoring(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetDescription(v)
@@ -832,9 +801,6 @@ func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *Er
if _u.mutation.CustomMessageCleared() {
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
}
if value, ok := _u.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
}

View File

@@ -325,7 +325,6 @@ var (
{Name: "response_code", Type: field.TypeInt, Nullable: true},
{Name: "passthrough_body", Type: field.TypeBool, Default: true},
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
{Name: "skip_monitoring", Type: field.TypeBool, Default: false},
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
}
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.
@@ -650,7 +649,6 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
{Name: "account_id", Type: field.TypeInt64},
@@ -666,31 +664,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[26]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -699,32 +697,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[26]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[25]},
},
{
Name: "usagelog_model",
@@ -739,12 +737,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
},
},
}

View File

@@ -5776,7 +5776,6 @@ type ErrorPassthroughRuleMutation struct {
addresponse_code *int
passthrough_body *bool
custom_message *string
skip_monitoring *bool
description *string
clearedFields map[string]struct{}
done bool
@@ -6504,42 +6503,6 @@ func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() {
delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage)
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) {
m.skip_monitoring = &b
}
// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation.
func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) {
v := m.skip_monitoring
if v == nil {
return
}
return *v, true
}
// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity.
// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSkipMonitoring requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err)
}
return oldValue.SkipMonitoring, nil
}
// ResetSkipMonitoring resets all changes to the "skip_monitoring" field.
func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() {
m.skip_monitoring = nil
}
// SetDescription sets the "description" field.
func (m *ErrorPassthroughRuleMutation) SetDescription(s string) {
m.description = &s
@@ -6623,7 +6586,7 @@ func (m *ErrorPassthroughRuleMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ErrorPassthroughRuleMutation) Fields() []string {
fields := make([]string, 0, 15)
fields := make([]string, 0, 14)
if m.created_at != nil {
fields = append(fields, errorpassthroughrule.FieldCreatedAt)
}
@@ -6663,9 +6626,6 @@ func (m *ErrorPassthroughRuleMutation) Fields() []string {
if m.custom_message != nil {
fields = append(fields, errorpassthroughrule.FieldCustomMessage)
}
if m.skip_monitoring != nil {
fields = append(fields, errorpassthroughrule.FieldSkipMonitoring)
}
if m.description != nil {
fields = append(fields, errorpassthroughrule.FieldDescription)
}
@@ -6703,8 +6663,6 @@ func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) {
return m.PassthroughBody()
case errorpassthroughrule.FieldCustomMessage:
return m.CustomMessage()
case errorpassthroughrule.FieldSkipMonitoring:
return m.SkipMonitoring()
case errorpassthroughrule.FieldDescription:
return m.Description()
}
@@ -6742,8 +6700,6 @@ func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string
return m.OldPassthroughBody(ctx)
case errorpassthroughrule.FieldCustomMessage:
return m.OldCustomMessage(ctx)
case errorpassthroughrule.FieldSkipMonitoring:
return m.OldSkipMonitoring(ctx)
case errorpassthroughrule.FieldDescription:
return m.OldDescription(ctx)
}
@@ -6846,13 +6802,6 @@ func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) er
}
m.SetCustomMessage(v)
return nil
case errorpassthroughrule.FieldSkipMonitoring:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSkipMonitoring(v)
return nil
case errorpassthroughrule.FieldDescription:
v, ok := value.(string)
if !ok {
@@ -7014,9 +6963,6 @@ func (m *ErrorPassthroughRuleMutation) ResetField(name string) error {
case errorpassthroughrule.FieldCustomMessage:
m.ResetCustomMessage()
return nil
case errorpassthroughrule.FieldSkipMonitoring:
m.ResetSkipMonitoring()
return nil
case errorpassthroughrule.FieldDescription:
m.ResetDescription()
return nil
@@ -15061,7 +15007,6 @@ type UsageLogMutation struct {
image_count *int
addimage_count *int
image_size *string
cache_ttl_overridden *bool
created_at *time.Time
clearedFields map[string]struct{}
user *int64
@@ -16688,42 +16633,6 @@ func (m *UsageLogMutation) ResetImageSize() {
delete(m.clearedFields, usagelog.FieldImageSize)
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
m.cache_ttl_overridden = &b
}
// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation.
func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) {
v := m.cache_ttl_overridden
if v == nil {
return
}
return *v, true
}
// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err)
}
return oldValue.CacheTTLOverridden, nil
}
// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field.
func (m *UsageLogMutation) ResetCacheTTLOverridden() {
m.cache_ttl_overridden = nil
}
// SetCreatedAt sets the "created_at" field.
func (m *UsageLogMutation) SetCreatedAt(t time.Time) {
m.created_at = &t
@@ -16929,7 +16838,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 31)
fields := make([]string, 0, 30)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -17017,9 +16926,6 @@ func (m *UsageLogMutation) Fields() []string {
if m.image_size != nil {
fields = append(fields, usagelog.FieldImageSize)
}
if m.cache_ttl_overridden != nil {
fields = append(fields, usagelog.FieldCacheTTLOverridden)
}
if m.created_at != nil {
fields = append(fields, usagelog.FieldCreatedAt)
}
@@ -17089,8 +16995,6 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageCount()
case usagelog.FieldImageSize:
return m.ImageSize()
case usagelog.FieldCacheTTLOverridden:
return m.CacheTTLOverridden()
case usagelog.FieldCreatedAt:
return m.CreatedAt()
}
@@ -17160,8 +17064,6 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageCount(ctx)
case usagelog.FieldImageSize:
return m.OldImageSize(ctx)
case usagelog.FieldCacheTTLOverridden:
return m.OldCacheTTLOverridden(ctx)
case usagelog.FieldCreatedAt:
return m.OldCreatedAt(ctx)
}
@@ -17376,13 +17278,6 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetImageSize(v)
return nil
case usagelog.FieldCacheTTLOverridden:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetCacheTTLOverridden(v)
return nil
case usagelog.FieldCreatedAt:
v, ok := value.(time.Time)
if !ok {
@@ -17796,9 +17691,6 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldImageSize:
m.ResetImageSize()
return nil
case usagelog.FieldCacheTTLOverridden:
m.ResetCacheTTLOverridden()
return nil
case usagelog.FieldCreatedAt:
m.ResetCreatedAt()
return nil

View File

@@ -326,10 +326,6 @@ func init() {
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
// errorpassthroughruleDescSkipMonitoring is the schema descriptor for skip_monitoring field.
errorpassthroughruleDescSkipMonitoring := errorpassthroughruleFields[11].Descriptor()
// errorpassthroughrule.DefaultSkipMonitoring holds the default value on creation for the skip_monitoring field.
errorpassthroughrule.DefaultSkipMonitoring = errorpassthroughruleDescSkipMonitoring.Default.(bool)
groupMixin := schema.Group{}.Mixin()
groupMixinHooks1 := groupMixin[1].Hooks()
group.Hooks[0] = groupMixinHooks1[0]
@@ -779,12 +775,8 @@ func init() {
usagelogDescImageSize := usagelogFields[28].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[29].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[30].Descriptor()
usagelogDescCreatedAt := usagelogFields[29].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -105,12 +105,6 @@ func (ErrorPassthroughRule) Fields() []ent.Field {
Optional().
Nillable(),
// skip_monitoring: 是否跳过运维监控记录
// true: 匹配此规则的错误不会被记录到 ops_error_logs
// false: 正常记录到运维监控(默认行为)
field.Bool("skip_monitoring").
Default(false),
// description: 规则描述,用于说明规则的用途
field.Text("description").
Optional().

View File

@@ -119,10 +119,6 @@ func (UsageLog) Fields() []ent.Field {
Optional().
Nillable(),
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").
Default(false),
// 时间戳(只有 created_at日志不可修改
field.Time("created_at").
Default(time.Now).

View File

@@ -80,8 +80,6 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
@@ -167,7 +165,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
case usagelog.FieldStream:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
@@ -380,12 +378,6 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string)
*_m.ImageSize = value.String
}
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
} else if value.Valid {
_m.CacheTTLOverridden = value.Bool
}
case usagelog.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
@@ -556,9 +548,6 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')')

View File

@@ -72,8 +72,6 @@ const (
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// EdgeUser holds the string denoting the user edge name in mutations.
@@ -157,7 +155,6 @@ var Columns = []string{
FieldIPAddress,
FieldImageCount,
FieldImageSize,
FieldCacheTTLOverridden,
FieldCreatedAt,
}
@@ -214,8 +211,6 @@ var (
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
)
@@ -373,11 +368,6 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
}
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()

View File

@@ -200,11 +200,6 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
}
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
@@ -1445,16 +1440,6 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
}
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
}
// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))

View File

@@ -393,20 +393,6 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
return _c
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate {
if v != nil {
_c.SetCacheTTLOverridden(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field.
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
_c.mutation.SetCreatedAt(v)
@@ -545,10 +531,6 @@ func (_c *UsageLogCreate) defaults() {
v := usagelog.DefaultImageCount
_c.mutation.SetImageCount(v)
}
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
v := usagelog.DefaultCacheTTLOverridden
_c.mutation.SetCacheTTLOverridden(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok {
v := usagelog.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
@@ -645,9 +627,6 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
}
@@ -783,10 +762,6 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value
}
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
}
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
@@ -1432,18 +1407,6 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
return u
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldCacheTTLOverridden)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2077,20 +2040,6 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetCacheTTLOverridden(v)
})
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateCacheTTLOverridden()
})
}
// Exec executes the query.
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2890,20 +2839,6 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetCacheTTLOverridden(v)
})
}
// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateCacheTTLOverridden()
})
}
// Exec executes the query.
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -612,20 +612,6 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
return _u
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate {
if v != nil {
_u.SetCacheTTLOverridden(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
return _u.SetUserID(v.ID)
@@ -908,9 +894,6 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
@@ -1656,20 +1639,6 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
return _u
}
// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne {
if v != nil {
_u.SetCacheTTLOverridden(*v)
}
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
return _u.SetUserID(v.ID)
@@ -1982,9 +1951,6 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,

View File

@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap
var out []service.Account
for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
if err != nil {
return nil, err
}

View File

@@ -156,12 +156,7 @@ func (h *AccountHandler) List(c *gin.Context) {
search = search[:100]
}
var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" {
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -429,17 +424,10 @@ type TestAccountRequest struct {
}
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
SelectedAccountIDs []string `json:"selected_account_ids"`
}
type PreviewFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
}
// Test handles testing account connectivity with SSE streaming
@@ -478,11 +466,10 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
SelectedAccountIDs: req.SelectedAccountIDs,
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
})
if err != nil {
// Provide detailed error message for CRS sync failures
@@ -493,28 +480,6 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
response.Success(c, result)
}
// PreviewFromCRS handles previewing accounts from CRS before sync
// POST /api/v1/admin/accounts/sync/crs/preview
func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
var req PreviewFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.crsSyncService.PreviewFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
})
if err != nil {
response.InternalError(c, "CRS preview failed: "+err.Error())
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
@@ -1434,7 +1399,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}

View File

@@ -65,27 +65,3 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
response.Success(c, tokenInfo)
}
// AntigravityRefreshTokenRequest represents the request for validating Antigravity refresh token
type AntigravityRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken validates an Antigravity refresh token and returns full token info
// POST /api/v1/admin/antigravity/oauth/refresh-token
func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) {
var req AntigravityRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}

View File

@@ -32,7 +32,6 @@ type CreateErrorPassthroughRuleRequest struct {
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"`
}
@@ -49,7 +48,6 @@ type UpdateErrorPassthroughRuleRequest struct {
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"`
}
@@ -124,9 +122,6 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
} else {
rule.PassthroughBody = true
}
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
rule.ResponseCode = req.ResponseCode
rule.CustomMessage = req.CustomMessage
rule.Description = req.Description
@@ -195,7 +190,6 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
ResponseCode: existing.ResponseCode,
PassthroughBody: existing.PassthroughBody,
CustomMessage: existing.CustomMessage,
SkipMonitoring: existing.SkipMonitoring,
Description: existing.Description,
}
@@ -236,9 +230,6 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
if req.Description != nil {
rule.Description = req.Description
}
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {

View File

@@ -202,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf)
// Write header
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
@@ -213,10 +213,6 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedBy != nil {
usedBy = fmt.Sprintf("%d", *code.UsedBy)
}
usedByEmail := ""
if code.User != nil {
usedByEmail = code.User.Email
}
usedAt := ""
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
@@ -228,7 +224,6 @@ func (h *RedeemHandler) Export(c *gin.Context) {
fmt.Sprintf("%.2f", code.Value),
code.Status,
usedBy,
usedByEmail,
usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
}); err != nil {

View File

@@ -211,13 +211,6 @@ func AccountFromServiceShallow(a *service.Account) *Account {
enabled := true
out.EnableSessionIDMasking = &enabled
}
// 缓存 TTL 强制替换
if a.IsCacheTTLOverrideEnabled() {
enabled := true
out.CacheTTLOverrideEnabled = &enabled
target := a.GetCacheTTLOverrideTarget()
out.CacheTTLOverrideTarget = &target
}
}
return out
@@ -405,7 +398,6 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
ImageCount: l.ImageCount,
ImageSize: l.ImageSize,
UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),

View File

@@ -150,11 +150,6 @@ type Account struct {
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -278,9 +273,6 @@ type UsageLog struct {
// User-Agent
UserAgent *string `json:"user_agent"`
// Cache TTL Override 标记
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`

View File

@@ -235,17 +235,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
@@ -253,19 +245,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
@@ -360,28 +339,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return
@@ -434,18 +396,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
fallbackUsed := false
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
}
for {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
@@ -458,19 +412,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
@@ -598,28 +539,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return
@@ -899,23 +823,6 @@ func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFa
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
)
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
func sleepSameAccountRetryDelay(ctx context.Context) bool {
select {
case <-ctx.Done():
return false
case <-time.After(sameAccountRetryDelay):
return true
}
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
@@ -931,27 +838,6 @@ func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
}
}
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
// 当分组内只有一个可用账号且上游返回 503MODEL_CAPACITY_EXHAUSTED时使用
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
// (最多 3 次、总等待 30s所以 Handler 层的退避只需短暂等待即可。
// 返回 false 表示 context 已取消。
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
// 固定短延时2s
// Service 层已经在原地等待了足够长的时间retryDelay × 重试次数),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
const delay = 2 * time.Second
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
@@ -971,10 +857,6 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
msg = *rule.CustomMessage
}
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}

View File

@@ -1,51 +0,0 @@
package handler
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// sleepAntigravitySingleAccountBackoff 测试
// ---------------------------------------------------------------------------
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.True(t, ok, "should return true when context is not canceled")
// 固定延迟 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
}
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.False(t, ok, "should return false when context is canceled")
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
}
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
// 验证不同 retryCount 都使用固定 2s 延迟
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
elapsed := time.Since(start)
require.True(t, ok)
// 即使 retryCount=5延迟仍然是固定的 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
require.Less(t, elapsed, 5*time.Second)
}

View File

@@ -327,13 +327,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
@@ -341,19 +334,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
}
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
@@ -554,10 +534,6 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
msg = *rule.CustomMessage
}
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
googleError(c, respCode, msg)
return
}

View File

@@ -354,10 +354,6 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
msg = *rule.CustomMessage
}
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}

View File

@@ -537,13 +537,6 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
if skip, _ := v.(bool); skip {
return
}
}
enqueueOpsErrorLog(ops, entry, requestBody)
return
}
@@ -551,13 +544,6 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
body := w.buf.Bytes()
parsed := parseOpsErrorResponse(body)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
if skip, _ := v.(bool); skip {
return
}
}
// Skip logging if the error should be filtered based on settings
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
return

View File

@@ -18,7 +18,6 @@ type ErrorPassthroughRule struct {
ResponseCode *int `json:"response_code"` // 自定义状态码passthrough_code=false 时使用)
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
CustomMessage *string `json:"custom_message"` // 自定义错误信息passthrough_body=false 时使用)
SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录
Description *string `json:"description"` // 规则描述
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`

View File

@@ -27,7 +27,7 @@ type ClaudeMessage struct {
// ThinkingConfig Thinking 配置
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" / "adaptive" / "disabled"
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
}

View File

@@ -115,23 +115,6 @@ type LoadCodeAssistResponse struct {
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
}
// OnboardUserRequest onboardUser 请求
type OnboardUserRequest struct {
TierID string `json:"tierId"`
Metadata struct {
IDEType string `json:"ideType"`
Platform string `json:"platform,omitempty"`
PluginType string `json:"pluginType,omitempty"`
} `json:"metadata"`
}
// OnboardUserResponse onboardUser 响应
type OnboardUserResponse struct {
Name string `json:"name,omitempty"`
Done bool `json:"done"`
Response map[string]any `json:"response,omitempty"`
}
// GetTier 获取账户类型
// 优先返回 paidTier付费订阅级别否则返回 currentTier
func (r *LoadCodeAssistResponse) GetTier() string {
@@ -378,117 +361,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, lastErr
}
// OnboardUser 触发账号 onboarding并返回 project_id
// 说明:
// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject
// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。
func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
tierID = strings.TrimSpace(tierID)
if tierID == "" {
return "", fmt.Errorf("tier_id 为空")
}
reqBody := OnboardUserRequest{TierID: tierID}
reqBody.Metadata.IDEType = "ANTIGRAVITY"
reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED"
reqBody.Metadata.PluginType = "GEMINI"
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
availableURLs := BaseURLs
var lastErr error
for urlIdx, baseURL := range availableURLs {
apiURL := baseURL + "/v1internal:onboardUser"
for attempt := 1; attempt <= 5; attempt++ {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes))
if err != nil {
lastErr = fmt.Errorf("创建请求失败: %w", err)
break
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("onboardUser 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
break
}
return "", lastErr
}
respBodyBytes, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
}
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
break
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
return "", lastErr
}
var onboardResp OnboardUserResponse
if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil {
lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err)
return "", lastErr
}
if onboardResp.Done {
if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" {
DefaultURLAvailability.MarkSuccess(baseURL)
return projectID, nil
}
lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id")
return "", lastErr
}
// done=false 时等待后重试(与 CLIProxyAPI 行为一致)
select {
case <-time.After(2 * time.Second):
case <-ctx.Done():
return "", ctx.Err()
}
}
}
if lastErr != nil {
return "", lastErr
}
return "", fmt.Errorf("onboardUser 未返回 project_id")
}
func extractProjectIDFromOnboardResponse(resp map[string]any) string {
if len(resp) == 0 {
return ""
}
if v, ok := resp["cloudaicompanionProject"]; ok {
switch project := v.(type) {
case string:
return strings.TrimSpace(project)
case map[string]any:
if id, ok := project["id"].(string); ok {
return strings.TrimSpace(id)
}
}
}
return ""
}
// ModelQuotaInfo 模型配额信息
type ModelQuotaInfo struct {
RemainingFraction float64 `json:"remainingFraction"`

View File

@@ -1,76 +0,0 @@
package antigravity
import (
"testing"
)
func TestExtractProjectIDFromOnboardResponse(t *testing.T) {
t.Parallel()
tests := []struct {
name string
resp map[string]any
want string
}{
{
name: "nil response",
resp: nil,
want: "",
},
{
name: "empty response",
resp: map[string]any{},
want: "",
},
{
name: "project as string",
resp: map[string]any{
"cloudaicompanionProject": "my-project-123",
},
want: "my-project-123",
},
{
name: "project as string with spaces",
resp: map[string]any{
"cloudaicompanionProject": " my-project-123 ",
},
want: "my-project-123",
},
{
name: "project as map with id",
resp: map[string]any{
"cloudaicompanionProject": map[string]any{
"id": "proj-from-map",
},
},
want: "proj-from-map",
},
{
name: "project as map without id",
resp: map[string]any{
"cloudaicompanionProject": map[string]any{
"name": "some-name",
},
},
want: "",
},
{
name: "missing cloudaicompanionProject key",
resp: map[string]any{
"otherField": "value",
},
want: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := extractProjectIDFromOnboardResponse(tc.resp)
if got != tc.want {
t.Fatalf("extractProjectIDFromOnboardResponse() = %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -155,7 +155,6 @@ type GeminiUsageMetadata struct {
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens按输出价格计费
}
// GeminiGroundingMetadata Gemini grounding 元数据Web Search

View File

@@ -64,10 +64,6 @@ const MaxTokensBudgetPadding = 1000
// Gemini 2.5 Flash thinking budget 上限
const Gemini25FlashThinkingBudgetLimit = 24576
// 对于 Antigravity 的 Claudebudget-only模型该语义最终等价为 thinkingBudget=24576。
// 这里复用相同数值以保持行为一致。
const ClaudeAdaptiveHighThinkingBudgetTokens = Gemini25FlashThinkingBudgetLimit
// ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens
// Claude API 要求启用 thinking 时max_tokens 必须大于 thinking.budget_tokens
// 返回调整后的 maxTokens 和是否进行了调整
@@ -100,7 +96,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
}
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
@@ -202,7 +198,8 @@ type modelInfo struct {
// modelInfoMap 模型前缀 → 模型信息映射
// 只有在此映射表中的模型才会注入身份提示词
// 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。
// 注意:当前 claude-opus-4-6 会被映射到 claude-opus-4-5-thinking
// 但保留此条目以便后续 Antigravity 上游支持 4.6 时快速切换
var modelInfoMap = map[string]modelInfo{
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
@@ -274,21 +271,6 @@ func filterOpenCodePrompt(text string) string {
return ""
}
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
func filterSystemBlockByPrefix(text string) string {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return ""
}
}
return text
}
// buildSystemInstruction 构建 systemInstruction与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart
@@ -305,8 +287,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词和黑名单前缀
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(sysStr)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
@@ -320,8 +302,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词和黑名单前缀
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(block.Text)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
@@ -596,10 +578,6 @@ func maxOutputTokensLimit(model string) int {
return maxOutputTokensUpperBound
}
func isAntigravityOpus46Model(model string) bool {
return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6")
}
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
maxLimit := maxOutputTokensLimit(req.Model)
config := &GeminiGenerationConfig{
@@ -613,36 +591,25 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
}
// Thinking 配置
if req.Thinking != nil && (req.Thinking.Type == "enabled" || req.Thinking.Type == "adaptive") {
if req.Thinking != nil && req.Thinking.Type == "enabled" {
config.ThinkingConfig = &GeminiThinkingConfig{
IncludeThoughts: true,
}
// - thinking.type=enabledbudget_tokens>0 用显式预算
// - thinking.type=adaptive仅在 Antigravity 的 Opus 4.6 上覆写为 24576
budget := -1
if req.Thinking.BudgetTokens > 0 {
budget = req.Thinking.BudgetTokens
}
if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) {
budget = ClaudeAdaptiveHighThinkingBudgetTokens
}
// 正预算需要做上限与 max_tokens 约束;动态预算(-1直接透传给上游。
if budget > 0 {
budget := req.Thinking.BudgetTokens
// gemini-2.5-flash 上限
if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit {
budget = Gemini25FlashThinkingBudgetLimit
}
config.ThinkingConfig.ThinkingBudget = budget
// 自动修正max_tokens 必须大于 budget_tokensClaude 上游要求)
// 自动修正max_tokens 必须大于 budget_tokens
if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok {
log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)",
config.MaxOutputTokens, adjusted, budget)
config.MaxOutputTokens = adjusted
}
}
config.ThinkingConfig.ThinkingBudget = budget
}
if config.MaxOutputTokens > maxLimit {

View File

@@ -259,93 +259,3 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
})
}
}
func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
tests := []struct {
name string
model string
thinking *ThinkingConfig
wantBudget int
wantPresent bool
}{
{
name: "enabled without budget defaults to dynamic (-1)",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "enabled"},
wantBudget: -1,
wantPresent: true,
},
{
name: "enabled with budget uses the provided value",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1024},
wantBudget: 1024,
wantPresent: true,
},
{
name: "enabled with -1 budget uses dynamic (-1)",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: -1},
wantBudget: -1,
wantPresent: true,
},
{
name: "adaptive on opus4.6 maps to high budget (24576)",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "adaptive", BudgetTokens: 20000},
wantBudget: ClaudeAdaptiveHighThinkingBudgetTokens,
wantPresent: true,
},
{
name: "adaptive on non-opus model keeps default dynamic (-1)",
model: "claude-sonnet-4-5-thinking",
thinking: &ThinkingConfig{Type: "adaptive"},
wantBudget: -1,
wantPresent: true,
},
{
name: "disabled does not emit thinkingConfig",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024},
wantBudget: 0,
wantPresent: false,
},
{
name: "nil thinking does not emit thinkingConfig",
model: "claude-opus-4-6-thinking",
thinking: nil,
wantBudget: 0,
wantPresent: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &ClaudeRequest{
Model: tt.model,
Thinking: tt.thinking,
}
cfg := buildGenerationConfig(req)
if cfg == nil {
t.Fatalf("expected non-nil generationConfig")
}
if tt.wantPresent {
if cfg.ThinkingConfig == nil {
t.Fatalf("expected thinkingConfig to be present")
}
if !cfg.ThinkingConfig.IncludeThoughts {
t.Fatalf("expected includeThoughts=true")
}
if cfg.ThinkingConfig.ThinkingBudget != tt.wantBudget {
t.Fatalf("expected thinkingBudget=%d, got %d", tt.wantBudget, cfg.ThinkingConfig.ThinkingBudget)
}
return
}
if cfg.ThinkingConfig != nil {
t.Fatalf("expected thinkingConfig to be nil, got %+v", cfg.ThinkingConfig)
}
})
}
}

View File

@@ -279,7 +279,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
if geminiResp.UsageMetadata != nil {
cached := geminiResp.UsageMetadata.CachedContentTokenCount
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}

View File

@@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
if geminiResp.UsageMetadata != nil {
cached := geminiResp.UsageMetadata.CachedContentTokenCount
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
p.cacheReadTokens = cached
}
@@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
if v1Resp.Response.UsageMetadata != nil {
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}

View File

@@ -10,7 +10,6 @@ const (
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
@@ -78,12 +77,6 @@ var DefaultModels = []Model{
DisplayName: "Claude Opus 4.6",
CreatedAt: "2026-02-06T00:00:00Z",
},
{
ID: "claude-sonnet-4-6",
Type: "model",
DisplayName: "Claude Sonnet 4.6",
CreatedAt: "2026-02-18T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",

View File

@@ -28,8 +28,4 @@ const (
// IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求
// 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent
IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku"
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
// 在此模式下Service 层的模型限流预检查将等待限流过期而非直接切换账号。
SingleAccountRetry Key = "ctx_single_account_retry"
)

View File

@@ -282,34 +282,6 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return &accounts[0], nil
}
func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT id, extra->>'crs_account_id'
FROM accounts
WHERE deleted_at IS NULL
AND extra->>'crs_account_id' IS NOT NULL
AND extra->>'crs_account_id' != ''
`)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make(map[string]int64)
for rows.Next() {
var id int64
var crsID string
if err := rows.Scan(&id, &crsID); err != nil {
return nil, err
}
result[crsID] = id
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
if account == nil {
return nil
@@ -435,10 +407,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "", 0)
return r.ListWithFilters(ctx, params, "", "", "", "")
}
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
q := r.client.Account.Query()
if platform != "" {
@@ -448,19 +420,11 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
q = q.Where(dbaccount.TypeEQ(accountType))
}
if status != "" {
switch status {
case "rate_limited":
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
default:
q = q.Where(dbaccount.StatusEQ(status))
}
q = q.Where(dbaccount.StatusEQ(status))
}
if search != "" {
q = q.Where(dbaccount.NameContainsFold(search))
}
if groupID > 0 {
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
}
total, err := q.Count(ctx)
if err != nil {

View File

@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {
@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1)

View File

@@ -54,8 +54,7 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody).
SetSkipMonitoring(rule.SkipMonitoring)
SetPassthroughBody(rule.PassthroughBody)
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
@@ -91,8 +90,7 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody).
SetSkipMonitoring(rule.SkipMonitoring)
SetPassthroughBody(rule.PassthroughBody)
// 处理可选字段
if len(rule.ErrorCodes) > 0 {
@@ -151,7 +149,6 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model
Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody,
SkipMonitoring: e.SkipMonitoring,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}

View File

@@ -6,7 +6,6 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -107,12 +106,7 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
q = q.Where(redeemcode.StatusEQ(status))
}
if search != "" {
q = q.Where(
redeemcode.Or(
redeemcode.CodeContainsFold(search),
redeemcode.HasUserWith(user.EmailContainsFold(search)),
),
)
q = q.Where(redeemcode.CodeContainsFold(search))
}
total, err := q.Count(ctx)

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -115,7 +115,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
image_count,
image_size,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
@@ -123,7 +122,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -174,7 +173,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.ImageCount,
imageSize,
reasoningEffort,
log.CacheTTLOverridden,
createdAt,
}
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
@@ -2197,7 +2195,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
imageCount int
imageSize sql.NullString
reasoningEffort sql.NullString
cacheTTLOverridden bool
createdAt time.Time
)
@@ -2233,7 +2230,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&imageCount,
&imageSize,
&reasoningEffort,
&cacheTTLOverridden,
&createdAt,
); err != nil {
return nil, err
@@ -2262,7 +2258,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
BillingType: int8(billingType),
Stream: stream,
ImageCount: imageCount,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: createdAt,
}

View File

@@ -10,7 +10,6 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -192,7 +191,6 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search),
dbuser.NotesContainsFold(filters.Search),
dbuser.HasAPIKeysWith(apikey.KeyContainsFold(filters.Search)),
),
)
}

View File

@@ -401,7 +401,6 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
"cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
}
@@ -937,7 +936,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
return nil, nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -1050,10 +1049,6 @@ func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
return int64(len(ids)), nil
}
func (s *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return nil, errors.New("not implemented")
}
type stubProxyRepo struct{}
func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error {

View File

@@ -70,15 +70,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
}
}
allowHeaders := []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key"}
// openai node sdk
openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"}
for _, prop := range openAIProperties {
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
}
c.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(allowHeaders, ", "))
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求

View File

@@ -209,7 +209,6 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
@@ -281,7 +280,6 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
{
antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL)
antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode)
antigravity.POST("/oauth/refresh-token", h.Admin.AntigravityOAuth.RefreshToken)
}
}

View File

@@ -752,38 +752,6 @@ func (a *Account) IsSessionIDMaskingEnabled() bool {
return false
}
// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型5m 或 1h
func (a *Account) IsCacheTTLOverrideEnabled() bool {
if !a.IsAnthropicOAuthOrSetupToken() {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["cache_ttl_override_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型
// 返回 "5m" 或 "1h",默认 "5m"
func (a *Account) GetCacheTTLOverrideTarget() string {
if a.Extra == nil {
return "5m"
}
if v, ok := a.Extra["cache_ttl_override_target"]; ok {
if target, ok := v.(string); ok && (target == "5m" || target == "1h") {
return target
}
}
return "5m"
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {

View File

@@ -25,14 +25,11 @@ type AccountRepository interface {
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error)

View File

@@ -54,10 +54,6 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st
panic("unexpected GetByCRSAccountID call")
}
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
panic("unexpected ListCRSAccountIDs call")
}
func (s *accountRepoStub) Update(ctx context.Context, account *Account) error {
panic("unexpected Update call")
}
@@ -75,7 +71,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
panic("unexpected List call")
}
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}

View File

@@ -39,7 +39,7 @@ type AdminService interface {
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
@@ -1021,9 +1021,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil {
return nil, 0, err
}

View File

@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
listWithFiltersErr error
}
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)

View File

@@ -16,12 +16,10 @@ import (
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -41,12 +39,6 @@ const (
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
// MODEL_CAPACITY_EXHAUSTED 专用重试参数
// 模型容量不足时,所有账号共享同一容量池,切换账号无意义
// 使用固定 1s 间隔重试,最多重试 60 次
antigravityModelCapacityRetryMaxAttempts = 60
antigravityModelCapacityRetryWait = 1 * time.Second
// Google RPC 状态和类型常量
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
googleRPCStatusUnavailable = "UNAVAILABLE"
@@ -54,22 +46,6 @@ const (
googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo"
googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED"
googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED"
// 单账号 503 退避重试Service 层原地重试的最大次数
// 在 handleSmartRetry 中,对于 shouldRateLimitModel长延迟 ≥ 7s的情况
// 多账号模式下会设限流+切换账号;但单账号模式下改为原地等待+重试。
antigravitySingleAccountSmartRetryMaxAttempts = 3
// 单账号 503 退避重试:原地重试时单次最大等待时间
// 防止上游返回过长的 retryDelay 导致请求卡住太久
antigravitySingleAccountSmartRetryMaxWait = 15 * time.Second
// 单账号 503 退避重试:原地重试的总累计等待时间上限
// 超过此上限将不再重试,直接返回 503
antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second
// MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间
antigravityModelCapacityCooldown = 10 * time.Second
)
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
@@ -78,15 +54,8 @@ var antigravityPassthroughErrorMessages = []string{
"prompt is too long",
}
// MODEL_CAPACITY_EXHAUSTED 全局去重:避免多个并发请求同时对同一模型进行容量耗尽重试
var (
modelCapacityExhaustedMu sync.RWMutex
modelCapacityExhaustedUntil = make(map[string]time.Time) // modelName -> cooldown until
)
const (
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
@@ -148,20 +117,6 @@ type antigravityRetryLoopResult struct {
resp *http.Response
}
// resolveAntigravityForwardBaseURL 解析转发用 base URL。
// 默认使用 dailyForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。
func resolveAntigravityForwardBaseURL() string {
baseURLs := antigravity.ForwardBaseURLs()
if len(baseURLs) == 0 {
return ""
}
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
if mode == "prod" && len(baseURLs) > 1 {
return baseURLs[1]
}
return baseURLs[0]
}
// smartRetryAction 智能重试的处理结果
type smartRetryAction int
@@ -189,17 +144,10 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
}
// 判断是否触发智能重试
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody)
// 情况1: retryDelay >= 阈值,限流模型并切换账号
if shouldRateLimitModel {
// 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
// 多账号场景下切换账号是最优选择,但单账号场景下设限流毫无意义(只会导致双重等待)。
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
return s.handleSingleAccountRetryInPlace(p, resp, respBody, baseURL, waitDuration, modelName)
}
rateLimitDuration := waitDuration
if rateLimitDuration <= 0 {
rateLimitDuration = antigravityDefaultRateLimitDuration
@@ -226,48 +174,20 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
}
}
// 情况2: retryDelay < 阈值(或 MODEL_CAPACITY_EXHAUSTED智能重试
// 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次)
if shouldSmartRetry {
var lastRetryResp *http.Response
var lastRetryBody []byte
// MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数60 次,固定 1s 间隔)
maxAttempts := antigravitySmartRetryMaxAttempts
if isModelCapacityExhausted {
maxAttempts = antigravityModelCapacityRetryMaxAttempts
waitDuration = antigravityModelCapacityRetryWait
// 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503
if modelName != "" {
modelCapacityExhaustedMu.RLock()
cooldownUntil, exists := modelCapacityExhaustedUntil[modelName]
modelCapacityExhaustedMu.RUnlock()
if exists && time.Now().Before(cooldownUntil) {
log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)",
p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05"))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
},
}
}
}
}
for attempt := 1; attempt <= maxAttempts; attempt++ {
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ {
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID)
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID)
timer := time.NewTimer(waitDuration)
select {
case <-p.ctx.Done():
timer.Stop()
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
case <-timer.C:
case <-time.After(waitDuration):
}
// 智能重试:创建新请求
@@ -287,19 +207,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
// 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown
if isModelCapacityExhausted && modelName != "" {
modelCapacityExhaustedMu.Lock()
delete(modelCapacityExhaustedUntil, modelName)
modelCapacityExhaustedMu.Unlock()
}
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
}
// 网络错误时,继续重试
if retryErr != nil || retryResp == nil {
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, maxAttempts, retryErr)
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
continue
}
@@ -309,20 +223,20 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
}
lastRetryResp = retryResp
if retryResp != nil {
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
}
// 解析新的重试信息,用于下次重试的等待时间MODEL_CAPACITY_EXHAUSTED 使用固定循环,跳过)
if !isModelCapacityExhausted && attempt < maxAttempts && lastRetryBody != nil {
newShouldRetry, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
// 解析新的重试信息,用于下次重试的等待时间
if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil {
newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
if newShouldRetry && newWaitDuration > 0 {
waitDuration = newWaitDuration
}
}
}
// 所有重试都失败
// 所有重试都失败,限流当前模型并切换账号
rateLimitDuration := waitDuration
if rateLimitDuration <= 0 {
rateLimitDuration = antigravityDefaultRateLimitDuration
@@ -331,45 +245,8 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
if retryBody == nil {
retryBody = respBody
}
// MODEL_CAPACITY_EXHAUSTED模型容量不足切换账号无意义
// 直接返回上游错误响应,不设置模型限流,不切换账号
if isModelCapacityExhausted {
// 设置 cooldown让后续请求快速失败避免重复重试
if modelName != "" {
modelCapacityExhaustedMu.Lock()
modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown)
modelCapacityExhaustedMu.Unlock()
}
log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)",
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
},
}
}
// 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号,
// 直接返回 503 让 Handler 层的单账号退避循环做最终处理。
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
log.Printf("%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)",
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
},
}
}
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
resetAt := time.Now().Add(rateLimitDuration)
if p.accountRepo != nil && modelName != "" {
@@ -402,163 +279,25 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
return &smartRetryResult{action: smartRetryActionContinue}
}
// handleSingleAccountRetryInPlace 单账号 503 退避重试的原地重试逻辑。
//
// 在多账号场景下,收到 503 + 长 retryDelay≥ 7s时会设置模型限流 + 切换账号;
// 但在单账号场景下,设限流毫无意义(因为切换回来的还是同一个账号,还要等限流过期)。
// 此方法改为在 Service 层原地等待 + 重试,避免双重等待问题:
//
// 旧流程Service 设限流 → Handler 退避等待 → Service 等限流过期 → 再请求(总耗时 = 退避 + 限流)
// 新流程Service 直接等 retryDelay → 重试 → 成功/再等 → 重试...(总耗时 ≈ 实际 retryDelay × 重试次数)
//
// 约束:
// - 单次等待不超过 antigravitySingleAccountSmartRetryMaxWait
// - 总累计等待不超过 antigravitySingleAccountSmartRetryTotalMaxWait
// - 最多重试 antigravitySingleAccountSmartRetryMaxAttempts 次
func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
p antigravityRetryLoopParams,
resp *http.Response,
respBody []byte,
baseURL string,
waitDuration time.Duration,
modelName string,
) *smartRetryResult {
// 限制单次等待时间
if waitDuration > antigravitySingleAccountSmartRetryMaxWait {
waitDuration = antigravitySingleAccountSmartRetryMaxWait
}
if waitDuration < antigravitySmartRetryMinWait {
waitDuration = antigravitySmartRetryMinWait
}
log.Printf("%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)",
p.prefix, resp.StatusCode, modelName, p.account.ID, waitDuration)
var lastRetryResp *http.Response
var lastRetryBody []byte
totalWaited := time.Duration(0)
for attempt := 1; attempt <= antigravitySingleAccountSmartRetryMaxAttempts; attempt++ {
// 检查累计等待是否超限
if totalWaited+waitDuration > antigravitySingleAccountSmartRetryTotalMaxWait {
remaining := antigravitySingleAccountSmartRetryTotalMaxWait - totalWaited
if remaining <= 0 {
log.Printf("%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up",
p.prefix, totalWaited, antigravitySingleAccountSmartRetryTotalMaxWait)
break
}
waitDuration = remaining
}
log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID)
timer := time.NewTimer(waitDuration)
select {
case <-p.ctx.Done():
timer.Stop()
log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix)
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
case <-timer.C:
}
totalWaited += waitDuration
// 创建新请求
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
if err != nil {
log.Printf("%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
break
}
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
log.Printf("%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v",
p.prefix, retryResp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited)
// 关闭之前的响应
if lastRetryResp != nil {
_ = lastRetryResp.Body.Close()
}
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
}
// 网络错误时继续重试
if retryErr != nil || retryResp == nil {
log.Printf("%s single_account_503_retry: network_error attempt=%d/%d error=%v",
p.prefix, attempt, antigravitySingleAccountSmartRetryMaxAttempts, retryErr)
continue
}
// 关闭之前的响应
if lastRetryResp != nil {
_ = lastRetryResp.Body.Close()
}
lastRetryResp = retryResp
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
_ = retryResp.Body.Close()
// 解析新的重试信息,更新下次等待时间
if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil {
_, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
if newWaitDuration > 0 {
waitDuration = newWaitDuration
if waitDuration > antigravitySingleAccountSmartRetryMaxWait {
waitDuration = antigravitySingleAccountSmartRetryMaxWait
}
if waitDuration < antigravitySmartRetryMinWait {
waitDuration = antigravitySmartRetryMinWait
}
}
}
}
// 所有重试都失败,不设限流,直接返回 503
// Handler 层的单账号退避循环会做最终处理
retryBody := lastRetryBody
if retryBody == nil {
retryBody = respBody
}
log.Printf("%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)",
p.prefix, resp.StatusCode, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited, modelName, p.account.ID, truncateForLog(retryBody, 200))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
},
}
}
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
// 预检查:如果账号已限流,直接返回切换信号
if p.requestedModel != "" {
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
// 如果上游确实还不可用handleSmartRetry → handleSingleAccountRetryInPlace
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
if isSingleAccountRetry(p.ctx) {
log.Printf("%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
} else {
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: p.requestedModel,
IsStickySession: p.isStickySession,
}
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: p.requestedModel,
IsStickySession: p.isStickySession,
}
}
}
baseURL := resolveAntigravityForwardBaseURL()
if baseURL == "" {
return nil, errors.New("no antigravity forward base url configured")
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs
}
availableURLs := []string{baseURL}
var resp *http.Response
var usedBaseURL string
@@ -632,12 +371,12 @@ urlFallbackLoop:
_ = resp.Body.Close()
// ★ 统一入口:自定义错误码 + 临时不可调度
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if policyErr != nil {
return nil, policyErr
}
resp = &http.Response{
StatusCode: outStatus,
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
@@ -871,22 +610,21 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
}
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。
// ErrorPolicySkipped 时 outStatus 为 500前端约定未命中的错误返回 500
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) {
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) {
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
case ErrorPolicySkipped:
return true, http.StatusInternalServerError, nil
return true, nil
case ErrorPolicyMatched:
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
return true, statusCode, nil
return true, nil
case ErrorPolicyTempUnscheduled:
slog.Info("temp_unschedulable_matched",
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
}
return false, statusCode, nil
return false, nil
}
// mapAntigravityModel 获取映射后的模型名
@@ -996,11 +734,11 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
proxyURL = account.Proxy.URL()
}
baseURL := resolveAntigravityForwardBaseURL()
if baseURL == "" {
return nil, errors.New("no antigravity forward base url configured")
// URL fallback 循环
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
if len(availableURLs) == 0 {
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
}
availableURLs := []string{baseURL}
var lastErr error
for urlIdx, baseURL := range availableURLs {
@@ -1309,7 +1047,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
}
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
// 获取 access_token
@@ -1465,7 +1203,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
break
}
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
if retryResp.StatusCode == http.StatusTooManyRequests {
retryBaseURL := ""
@@ -1546,27 +1284,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
if resp.StatusCode == http.StatusBadRequest {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if isGoogleProjectConfigError(msg) {
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
upstreamDetail := s.getUpstreamErrorDetail(respBody)
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
}
}
if s.shouldFailoverUpstreamError(resp.StatusCode) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
@@ -2107,22 +1824,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// Always record upstream context for Ops error logs, even when we will failover.
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
if resp.StatusCode == http.StatusBadRequest && isGoogleProjectConfigError(strings.ToLower(upstreamMsg)) {
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps, RetryableOnSameAccount: true}
}
if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
@@ -2218,44 +1919,6 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
}
}
// isGoogleProjectConfigError 判断(已提取的小写)错误消息是否属于 Google 服务端配置类问题。
// 只精确匹配已知的服务端侧错误,避免对客户端请求错误做无意义重试。
// 适用于所有走 Google 后端的平台Antigravity、Gemini
func isGoogleProjectConfigError(lowerMsg string) bool {
// Google 间歇性 BugProject ID 有效但被临时识别失败
return strings.Contains(lowerMsg, "invalid project resource name")
}
// googleConfigErrorCooldown 服务端配置类 400 错误的临时封禁时长
const googleConfigErrorCooldown = 1 * time.Minute
// tempUnscheduleGoogleConfigError 对服务端配置类 400 错误触发临时封禁,
// 避免短时间内反复调度到同一个有问题的账号。
func tempUnscheduleGoogleConfigError(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
until := time.Now().Add(googleConfigErrorCooldown)
reason := "400: invalid project resource name (auto temp-unschedule 1m)"
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
} else {
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
}
}
// emptyResponseCooldown 空流式响应的临时封禁时长
const emptyResponseCooldown = 1 * time.Minute
// tempUnscheduleEmptyResponse 对空流式响应触发临时封禁,
// 避免短时间内反复调度到同一个返回空响应的账号。
func tempUnscheduleEmptyResponse(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
until := time.Now().Add(emptyResponseCooldown)
reason := "empty stream response (auto temp-unschedule 1m)"
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
} else {
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
}
}
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
@@ -2272,22 +1935,14 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
sleepFor = 0
}
timer := time.NewTimer(sleepFor)
select {
case <-ctx.Done():
timer.Stop()
return false
case <-timer.C:
case <-time.After(sleepFor):
return true
}
}
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
func isSingleAccountRetry(ctx context.Context) bool {
v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool)
return v
}
// setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流
// 直接使用上游返回的模型 ID如 claude-sonnet-4-5作为限流 key
// 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false
@@ -2322,9 +1977,8 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
// antigravitySmartRetryInfo 智能重试所需的信息
type antigravitySmartRetryInfo struct {
RetryDelay time.Duration // 重试延迟时间
ModelName string // 限流的模型名称(如 "claude-sonnet-4-5"
IsModelCapacityExhausted bool // 是否为模型容量不足MODEL_CAPACITY_EXHAUSTED
RetryDelay time.Duration // 重试延迟时间
ModelName string // 限流的模型名称(如 "claude-sonnet-4-5"
}
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
@@ -2439,40 +2093,31 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
}
return &antigravitySmartRetryInfo{
RetryDelay: retryDelay,
ModelName: modelName,
IsModelCapacityExhausted: hasModelCapacityExhausted,
RetryDelay: retryDelay,
ModelName: modelName,
}
}
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
// 返回:
// - shouldRetry: 是否应该智能重试retryDelay < antigravityRateLimitThreshold,或 MODEL_CAPACITY_EXHAUSTED
// - shouldRateLimitModel: 是否应该限流模型并切换账号(仅 RATE_LIMIT_EXCEEDED 且 retryDelay >= 阈值
// - waitDuration: 等待时间
// - shouldRetry: 是否应该智能重试retryDelay < antigravityRateLimitThreshold
// - shouldRateLimitModel: 是否应该限流模型retryDelay >= antigravityRateLimitThreshold
// - waitDuration: 等待时间智能重试时使用shouldRateLimitModel=true 时为 0
// - modelName: 限流的模型名称
// - isModelCapacityExhausted: 是否为模型容量不足MODEL_CAPACITY_EXHAUSTED
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string, isModelCapacityExhausted bool) {
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) {
if account.Platform != PlatformAntigravity {
return false, false, 0, "", false
return false, false, 0, ""
}
info := parseAntigravitySmartRetryInfo(respBody)
if info == nil {
return false, false, 0, "", false
return false, false, 0, ""
}
// MODEL_CAPACITY_EXHAUSTED模型容量不足所有账号共享同一模型容量池
// 切换账号无意义,使用固定 1s 间隔重试
if info.IsModelCapacityExhausted {
return true, false, antigravityModelCapacityRetryWait, info.ModelName, true
}
// RATE_LIMIT_EXCEEDED账号级限流
// retryDelay >= 阈值:直接限流模型,不重试
// 注意:如果上游未提供 retryDelayparseAntigravitySmartRetryInfo 已设置为默认 30s
if info.RetryDelay >= antigravityRateLimitThreshold {
return false, true, info.RetryDelay, info.ModelName, false
return false, true, info.RetryDelay, info.ModelName
}
// retryDelay < 阈值:智能重试
@@ -2481,7 +2126,7 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou
waitDuration = antigravitySmartRetryMinWait
}
return true, false, waitDuration, info.ModelName, false
return true, false, waitDuration, info.ModelName
}
// handleModelRateLimitParams 模型级限流处理参数
@@ -2507,9 +2152,8 @@ type handleModelRateLimitResult struct {
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
// 仅处理 429/503解析模型名和 retryDelay
// - MODEL_CAPACITY_EXHAUSTED: 返回 Handled=true实际重试由 handleSmartRetry 处理)
// - RATE_LIMIT_EXCEEDED + retryDelay < 阈值: 返回 ShouldRetry=true由调用方等待后重试
// - RATE_LIMIT_EXCEEDED + retryDelay >= 阈值: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true由调用方等待后重试
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
if p.statusCode != 429 && p.statusCode != 503 {
return &handleModelRateLimitResult{Handled: false}
@@ -2520,17 +2164,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
return &handleModelRateLimitResult{Handled: false}
}
// MODEL_CAPACITY_EXHAUSTED模型容量不足所有账号共享同一容量池
// 切换账号无意义,不设置模型限流(实际重试由 handleSmartRetry 处理)
if info.IsModelCapacityExhausted {
log.Printf("%s status=%d model_capacity_exhausted model=%s (not switching account, retry handled by smart retry)",
p.prefix, p.statusCode, info.ModelName)
return &handleModelRateLimitResult{
Handled: true,
}
}
// RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试
// < antigravityRateLimitThreshold: 等待后重试
if info.RetryDelay < antigravityRateLimitThreshold {
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
@@ -2541,7 +2175,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
}
}
// RATE_LIMIT_EXCEEDED: >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
s.setModelRateLimitAndClearSession(p, info)
return &handleModelRateLimitResult{
@@ -2608,10 +2242,6 @@ func (s *AntigravityGatewayService) handleUpstreamError(
requestedModel string,
groupID int64, sessionHash string, isStickySession bool,
) *handleModelRateLimitResult {
// 遵守自定义错误码策略:未命中则跳过所有限流处理
if !account.ShouldHandleErrorCode(statusCode) {
return nil
}
// 模型级限流处理(优先)
result := s.handleModelRateLimit(&handleModelRateLimitParams{
ctx: ctx,
@@ -3089,14 +2719,9 @@ returnResponse:
// 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况 — 触发同账号重试 + failover 切换账号
// 处理空响应情况
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover")
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
RetryableOnSameAccount: true,
}
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
}
// 如果收集到了图片 parts需要合并到最终响应中
@@ -3314,21 +2939,6 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
}
// 检查错误透传规则
if ptStatus, ptErrType, ptErrMsg, matched := applyErrorPassthroughRule(
c, account.Platform, upstreamStatus, body,
0, "", "",
); matched {
c.JSON(ptStatus, gin.H{
"type": "error",
"error": gin.H{"type": ptErrType, "message": ptErrMsg},
})
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
var statusCode int
var errType, errMsg string
@@ -3524,14 +3134,10 @@ returnResponse:
// 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况 — 触发同账号重试 + failover 切换账号
// 处理空响应情况
if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover")
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
RetryableOnSameAccount: true,
}
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
}
// 将收集的所有 parts 合并到最终响应中
@@ -4111,15 +3717,6 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
if cc, ok := u["cache_creation"].(map[string]any); ok {
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
usage.CacheCreation5mTokens = int(v)
}
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
usage.CacheCreation1hTokens = int(v)
}
}
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
@@ -4142,15 +3739,6 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
if cc, ok := u["cache_creation"].(map[string]any); ok {
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
usage.CacheCreation5mTokens = int(v)
}
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
usage.CacheCreation1hTokens = int(v)
}
}
}
return usage
}

View File

@@ -553,75 +553,6 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
require.NotContains(t, body, "event: error")
}
// TestHandleGeminiStreamingResponse_ThoughtsTokenCount
// 验证Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens
func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
// promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90
require.Equal(t, 90, result.usage.InputTokens)
// candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110
require.Equal(t, 110, result.usage.OutputTokens)
require.Equal(t, 10, result.usage.CacheReadInputTokens)
}
// TestHandleClaudeStreamingResponse_ThoughtsTokenCount
// 验证Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens
func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
// promptTokenCount=50 → InputTokens=50
require.Equal(t, 50, result.usage.InputTokens)
// candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35
require.Equal(t, 35, result.usage.OutputTokens)
}
// --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage

View File

@@ -192,43 +192,6 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
}
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) {
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// 刷新 token
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, err
}
// 获取用户信息email
client := antigravity.NewClient(proxyURL)
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
} else {
tokenInfo.Email = userInfo.Email
}
// 获取 project_id容错失败不阻塞
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
tokenInfo.ProjectIDMissing = true
} else {
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
func isNonRetryableAntigravityOAuthError(err error) bool {
msg := err.Error()
nonRetryable := []string{
@@ -310,21 +273,12 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
}
client := antigravity.NewClient(proxyURL)
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
}
if err == nil {
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
return projectID, nil
} else if onboardErr != nil {
lastErr = onboardErr
continue
}
}
// 记录错误
if err != nil {
lastErr = err
@@ -338,65 +292,6 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
tierID := resolveDefaultTierID(loadRaw)
if tierID == "" {
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
}
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
if err != nil {
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
}
return projectID, nil
}
func resolveDefaultTierID(loadRaw map[string]any) string {
if len(loadRaw) == 0 {
return ""
}
rawTiers, ok := loadRaw["allowedTiers"]
if !ok {
return ""
}
tiers, ok := rawTiers.([]any)
if !ok {
return ""
}
for _, rawTier := range tiers {
tier, ok := rawTier.(map[string]any)
if !ok {
continue
}
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
continue
}
if id, ok := tier["id"].(string); ok {
id = strings.TrimSpace(id)
if id != "" {
return id
}
}
}
return ""
}
// FillProjectID 仅获取 project_id不刷新 OAuth token
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
}
// BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{

View File

@@ -1,82 +0,0 @@
package service
import (
"testing"
)
func TestResolveDefaultTierID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
loadRaw map[string]any
want string
}{
{
name: "nil loadRaw",
loadRaw: nil,
want: "",
},
{
name: "missing allowedTiers",
loadRaw: map[string]any{
"paidTier": map[string]any{"id": "g1-pro-tier"},
},
want: "",
},
{
name: "empty allowedTiers",
loadRaw: map[string]any{"allowedTiers": []any{}},
want: "",
},
{
name: "tier missing id field",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"isDefault": true},
},
},
want: "",
},
{
name: "allowedTiers but no default",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": false},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "",
},
{
name: "default tier found",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": true},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "free-tier",
},
{
name: "default tier id with spaces",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": " standard-tier ", "isDefault": true},
},
},
want: "standard-tier",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := resolveDefaultTierID(tc.loadRaw)
if got != tc.want {
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -86,9 +86,7 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
return nil
}
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
@@ -133,16 +131,15 @@ func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.True(t, handleErrorCalled)
require.Len(t, upstream.calls, antigravityMaxRetries)
for _, callURL := range upstream.calls {
require.True(t, strings.HasPrefix(callURL, base1))
}
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.False(t, handleErrorCalled)
require.Len(t, upstream.calls, 2)
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
require.NotEmpty(t, available)
require.Equal(t, base1, available[0])
require.Equal(t, base2, available[0])
}
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
@@ -191,14 +188,13 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
// 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号
// 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流
body := []byte(`{
"error": {
"status": "UNAVAILABLE",
@@ -211,13 +207,13 @@ func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流
// 实际重试由 handleSmartRetry 处理
// 应该触发模型限流
require.NotNil(t, result)
require.True(t, result.Handled)
require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path")
require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch")
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
require.NotNil(t, result.SwitchError)
require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
@@ -305,12 +301,11 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
tests := []struct {
name string
body string
expectedDelay time.Duration
expectedModel string
expectedNil bool
expectedIsModelCapacityExhausted bool
name string
body string
expectedDelay time.Duration
expectedModel string
expectedNil bool
}{
{
name: "valid complete response with RATE_LIMIT_EXCEEDED",
@@ -373,9 +368,8 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`,
expectedDelay: 39 * time.Second,
expectedModel: "gemini-3-pro-high",
expectedIsModelCapacityExhausted: true,
expectedDelay: 39 * time.Second,
expectedModel: "gemini-3-pro-high",
},
{
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
@@ -486,9 +480,6 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
if result.ModelName != tt.expectedModel {
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
}
if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
}
})
}
}
@@ -500,14 +491,13 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
apiKeyAccount := &Account{Type: AccountTypeAPIKey}
tests := []struct {
name string
account *Account
body string
expectedShouldRetry bool
expectedShouldRateLimit bool
expectedIsModelCapacityExhausted bool
minWait time.Duration
modelName string
name string
account *Account
body string
expectedShouldRetry bool
expectedShouldRateLimit bool
minWait time.Duration
modelName string
}{
{
name: "OAuth account with short delay (< 7s) - smart retry",
@@ -621,14 +611,13 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
]
}
}`,
expectedShouldRetry: true,
expectedShouldRateLimit: false,
expectedIsModelCapacityExhausted: true,
minWait: 1 * time.Second,
modelName: "gemini-3-pro-high",
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 39 * time.Second,
modelName: "gemini-3-pro-high",
},
{
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait",
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit",
account: oauthAccount,
body: `{
"error": {
@@ -640,11 +629,10 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
"message": "No capacity available for model gemini-2.5-flash on the server"
}
}`,
expectedShouldRetry: true,
expectedShouldRateLimit: false,
expectedIsModelCapacityExhausted: true,
minWait: 1 * time.Second,
modelName: "gemini-2.5-flash",
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 30 * time.Second,
modelName: "gemini-2.5-flash",
},
{
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
@@ -668,16 +656,13 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
if shouldRetry != tt.expectedShouldRetry {
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
}
if shouldRateLimit != tt.expectedShouldRateLimit {
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
}
if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
}
if shouldRetry {
if wait < tt.minWait {
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
@@ -930,22 +915,6 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) {
}
}
func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
defer func() {
antigravity.BaseURLs = oldBaseURLs
}()
prodURL := "https://prod.test"
dailyURL := "https://daily.test"
antigravity.BaseURLs = []string{dailyURL, prodURL}
resolved := resolveAntigravityForwardBaseURL()
require.Equal(t, dailyURL, resolved)
}
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
err := &AntigravityAccountSwitchError{
OriginalAccountID: 789,

View File

@@ -1,904 +0,0 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// 辅助函数:构造带 SingleAccountRetry 标记的 context
// ---------------------------------------------------------------------------
func ctxWithSingleAccountRetry() context.Context {
return context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
}
// ---------------------------------------------------------------------------
// 1. isSingleAccountRetry 测试
// ---------------------------------------------------------------------------
func TestIsSingleAccountRetry_True(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
require.True(t, isSingleAccountRetry(ctx))
}
func TestIsSingleAccountRetry_False_NoValue(t *testing.T) {
require.False(t, isSingleAccountRetry(context.Background()))
}
func TestIsSingleAccountRetry_False_ExplicitFalse(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, false)
require.False(t, isSingleAccountRetry(ctx))
}
func TestIsSingleAccountRetry_False_WrongType(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, "true")
require.False(t, isSingleAccountRetry(ctx))
}
// ---------------------------------------------------------------------------
// 2. 常量验证
// ---------------------------------------------------------------------------
func TestSingleAccountRetryConstants(t *testing.T) {
require.Equal(t, 3, antigravitySingleAccountSmartRetryMaxAttempts,
"单账号原地重试最多 3 次")
require.Equal(t, 15*time.Second, antigravitySingleAccountSmartRetryMaxWait,
"单次最大等待 15s")
require.Equal(t, 30*time.Second, antigravitySingleAccountSmartRetryTotalMaxWait,
"总累计等待不超过 30s")
}
// ---------------------------------------------------------------------------
// 3. handleSmartRetry + 503 + SingleAccountRetry → 走 handleSingleAccountRetryInPlace
// (而非设模型限流 + 切换账号)
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace
// 核心场景503 + retryDelay >= 7s + SingleAccountRetry 标记
// → 不设模型限流、不切换账号,改为原地重试
func TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace(t *testing.T) {
// 原地重试成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 1,
Name: "acc-single",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
// 503 + 39s >= 7s 阈值 + MODEL_CAPACITY_EXHAUSTED
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(), // 关键:设置单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键断言:返回 resp原地重试成功而非 switchError切换账号
require.NotNil(t, result.resp, "should return successful response from in-place retry")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT return switchError in single account mode")
require.Nil(t, result.err)
// 验证未设模型限流(单账号模式不应设限流)
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit in single account retry mode")
// 验证确实调用了 upstream原地重试
require.GreaterOrEqual(t, len(upstream.calls), 1, "should have made at least one retry call")
}
// TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches
// 对照组503 + retryDelay >= 7s + 无 SingleAccountRetry 标记
// → 照常设模型限流 + 切换账号
func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 2,
Name: "acc-multi",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel
respBody := []byte(`{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(), // 关键:无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 对照:多账号模式返回 switchError
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
require.Nil(t, result.resp, "should not return resp when switchError is set")
// 对照:多账号模式应设模型限流
require.Len(t, repo.modelRateLimitCalls, 1,
"multi-account mode SHOULD set model rate limit")
}
// TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches
// 边界情况429非 503+ SingleAccountRetry 标记
// → 单账号原地重试仅针对 503429 依然走切换账号逻辑
func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-429",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 429 + 15s >= 7s 阈值
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests, // 429不是 503
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(), // 有单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 429 即使有单账号标记,也应走切换账号
require.NotNil(t, result.switchError, "429 should still return switchError even with SingleAccountRetry")
require.Len(t, repo.modelRateLimitCalls, 1,
"429 should still set model rate limit even with SingleAccountRetry")
}
// ---------------------------------------------------------------------------
// 4. handleSmartRetry + 503 + 短延迟 + SingleAccountRetry → 智能重试耗尽后不设限流
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503不设限流
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
// 智能重试也返回 503
failRespBody := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 4,
Name: "acc-short-503",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 0.1s < 7s 阈值
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键断言:单账号 503 模式下,智能重试耗尽后直接返回 503 响应,不切换
require.NotNil(t, result.resp, "should return 503 response directly for single account mode")
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT switch account in single account mode")
// 关键断言:不设模型限流
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit for 503 in single account mode")
}
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
// 对照组503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED因为后者走独立的 60 次重试路径
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
failRespBody := `{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 5,
Name: "acc-multi-503",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(), // 无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 对照:多账号模式应返回 switchError
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
// 对照:多账号模式应设模型限流
require.Len(t, repo.modelRateLimitCalls, 1,
"multi-account mode should set model rate limit")
}
// ---------------------------------------------------------------------------
// 5. handleSingleAccountRetryInPlace 直接测试
// ---------------------------------------------------------------------------
// TestHandleSingleAccountRetryInPlace_Success 原地重试成功
func TestHandleSingleAccountRetryInPlace_Success(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
account := &Account{
ID: 10,
Name: "acc-inplace-ok",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not switch account on success")
require.Nil(t, result.err)
}
// TestHandleSingleAccountRetryInPlace_AllRetriesFail 所有重试都失败,返回 503不设限流
func TestHandleSingleAccountRetryInPlace_AllRetriesFail(t *testing.T) {
// 构造 3 个 503 响应(对应 3 次原地重试)
var responses []*http.Response
var errors []error
for i := 0; i < antigravitySingleAccountSmartRetryMaxAttempts; i++ {
responses = append(responses, &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)),
})
errors = append(errors, nil)
}
upstream := &mockSmartRetryUpstream{
responses: responses,
errors: errors,
}
account := &Account{
ID: 11,
Name: "acc-inplace-fail",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
origBody := []byte(`{"error":{"code":503,"status":"UNAVAILABLE"}}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{"X-Test": {"original"}},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, origBody, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键:返回 503 resp不返回 switchError
require.NotNil(t, result.resp, "should return 503 response directly")
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT return switchError - let Handler handle it")
require.Nil(t, result.err)
// 验证确实重试了指定次数
require.Len(t, upstream.calls, antigravitySingleAccountSmartRetryMaxAttempts,
"should have made exactly maxAttempts retry calls")
}
// TestHandleSingleAccountRetryInPlace_WaitDurationClamped 等待时间被限制在 [min, max] 范围
func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
// 用短延迟的成功响应,只验证不 panic
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
account := &Account{
ID: 12,
Name: "acc-clamp",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Equal(t, http.StatusOK, result.resp.StatusCode)
}
// TestHandleSingleAccountRetryInPlace_ContextCanceled context 取消时立即返回
func TestHandleSingleAccountRetryInPlace_ContextCanceled(t *testing.T) {
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil},
errors: []error{nil},
}
account := &Account{
ID: 13,
Name: "acc-cancel",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
ctx, cancel := context.WithCancel(context.Background())
ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
cancel() // 立即取消
params := antigravityRetryLoopParams{
ctx: ctx,
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Error(t, result.err, "should return context error")
// 不应调用 upstream因为在等待阶段就被取消了
require.Len(t, upstream.calls, 0, "should not call upstream when context is canceled")
}
// TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry 网络错误时继续重试
func TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
// 第1次网络错误nil resp第2次成功
responses: []*http.Response{nil, successResp},
errors: []error{nil, nil},
}
account := &Account{
ID: 14,
Name: "acc-net-retry",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after network error recovery")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Len(t, upstream.calls, 2, "first call fails (network error), second succeeds")
}
// ---------------------------------------------------------------------------
// 6. antigravityRetryLoop 预检查:单账号模式跳过限流
// ---------------------------------------------------------------------------
// TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit
// 预检查中,如果有 SingleAccountRetry 标记,即使账号已限流也跳过直接发请求
func TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit(t *testing.T) {
// 创建一个已设模型限流的账号
upstream := &recordingOKUpstream{}
account := &Account{
ID: 20,
Name: "acc-rate-limited",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err, "should not return error")
require.NotNil(t, result, "should return result")
require.NotNil(t, result.resp, "should have response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
// 关键:尽管限流了,有 SingleAccountRetry 标记时仍然到达了 upstream
require.Equal(t, 1, upstream.calls, "should have reached upstream despite rate limit")
}
// TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit
// 对照组:无 SingleAccountRetry + 已限流 → 预检查返回 switchError
func TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 21,
Name: "acc-rate-limited-multi",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(), // 无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.Nil(t, result, "should not return result on rate limit switch")
require.NotNil(t, err, "should return error")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr, "should return AntigravityAccountSwitchError")
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
// upstream 不应被调用(预检查就短路了)
require.Equal(t, 0, upstream.calls, "upstream should NOT be called when pre-check blocks")
}
// ---------------------------------------------------------------------------
// 7. 端到端集成场景测试
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E
// 端到端场景503 + 单账号 + 原地重试第2次成功
func TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E(t *testing.T) {
// 第1次原地重试仍返回 503第2次成功
fail503Body := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
resp503 := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(fail503Body)),
}
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{resp503, successResp},
errors: []error{nil, nil},
}
account := &Account{
ID: 30,
Name: "acc-e2e",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after 2nd attempt")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError)
require.Len(t, upstream.calls, 2, "first 503, second OK")
}
// TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E
// 通过 antigravityRetryLoop → handleSmartRetry → handleSingleAccountRetryInPlace 完整链路
func TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E(t *testing.T) {
// 初始请求返回 503 + 长延迟
initial503Body := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "10s"}
],
"message": "No capacity available"
}
}`)
initial503Resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(initial503Body)),
}
// 原地重试成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
// 第1次调用retryLoop 主循环)返回 503
// 第2次调用handleSingleAccountRetryInPlace 原地重试)返回 200
responses: []*http.Response{initial503Resp, successResp},
errors: []error{nil, nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 31,
Name: "acc-e2e-loop",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err, "should not return error on successful retry")
require.NotNil(t, result, "should return result")
require.NotNil(t, result.resp, "should return response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
// 验证未设模型限流
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit in single account retry mode")
}

View File

@@ -294,9 +294,8 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功
// MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号
func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) {
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
@@ -305,7 +304,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T)
Platform: PlatformAntigravity,
}
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值
respBody := []byte(`{
"error": {
"code": 503,
@@ -323,14 +322,6 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T)
Body: io.NopCloser(bytes.NewReader(respBody)),
}
// mock: 第 1 次重试返回 200 成功
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{
{StatusCode: http.StatusOK, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))},
},
errors: []error{nil},
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
@@ -339,7 +330,6 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T)
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
@@ -353,67 +343,16 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.resp)
require.Nil(t, result.err)
require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError")
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
// 不应设置模型限流
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
require.Len(t, upstream.calls, 1, "should have made one retry call before success")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消
func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-3",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
// 立即取消上下文,验证重试循环能正确退出
ctx, cancel := context.WithCancel(context.Background())
cancel()
params := antigravityRetryLoopParams{
ctx: ctx,
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Error(t, result.err, "should return context error")
require.Nil(t, result.switchError, "should not return switchError on context cancel")
require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel")
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
@@ -1190,20 +1129,20 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
}
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
// 429 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
@@ -1223,16 +1162,16 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
respBody := []byte(`{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}

View File

@@ -7,14 +7,12 @@ import (
"log/slog"
"strconv"
"strings"
"sync"
"time"
)
const (
antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute
antigravityBackfillCooldown = 5 * time.Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
@@ -25,7 +23,6 @@ type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
}
func NewAntigravityTokenProvider(
@@ -96,7 +93,13 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if err != nil {
return "", err
}
p.mergeCredentials(account, tokenInfo)
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
@@ -110,21 +113,6 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials")
}
// 如果账号还没有 project_id尝试在线补齐避免请求 daily/sandbox 时出现
// "Invalid project resource name projects/"。
// 仅调用 loadProjectIDWithRetry不刷新 OAuth token带冷却机制防止频繁重试。
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
if p.shouldAttemptBackfill(account.ID) {
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
}
}
}
}
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
@@ -156,31 +144,6 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil
}
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
}
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id冷却期内不重复尝试
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
if v, ok := p.backfillCooldown.Load(accountID); ok {
if lastAttempt, ok := v.(time.Time); ok {
return time.Since(lastAttempt) > antigravityBackfillCooldown
}
}
return true
}
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
p.backfillCooldown.Store(accountID, time.Now())
}
func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {

View File

@@ -31,8 +31,8 @@ type ModelPricing struct {
OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建价格每百万token- 仅用于硬编码回退
CacheCreation1hPrice float64 // 1小时缓存创建价格每百万token- 仅用于硬编码回退
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
}
@@ -172,20 +172,12 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
if s.pricingService != nil {
litellmPricing := s.pricingService.GetModelPricing(model)
if litellmPricing != nil {
// 启用 5m/1h 分类计费的条件:
// 1. 存在 1h 价格
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
price5m := litellmPricing.CacheCreationInputTokenCost
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
enableBreakdown := price1h > 0 && price1h > price5m
return &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
SupportsCacheBreakdown: false,
}, nil
}
}
@@ -217,14 +209,9 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
// 支持详细缓存分类的模型5分钟/1小时缓存,价格为 per-token
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
} else {
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
}
// 支持详细缓存分类的模型5分钟/1小时缓存
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
} else {
// 标准缓存创建价格per-token
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
@@ -293,12 +280,10 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
// 范围内部分:正常计费
inRangeTokens := UsageTokens{
InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {

View File

@@ -1,112 +0,0 @@
package service
import (
"testing"
)
func TestBuildSelectedSet(t *testing.T) {
tests := []struct {
name string
ids []string
wantNil bool
wantSize int
}{
{
name: "nil input returns nil (backward compatible: create all)",
ids: nil,
wantNil: true,
},
{
name: "empty slice returns empty map (create none)",
ids: []string{},
wantNil: false,
wantSize: 0,
},
{
name: "single ID",
ids: []string{"abc-123"},
wantNil: false,
wantSize: 1,
},
{
name: "multiple IDs",
ids: []string{"a", "b", "c"},
wantNil: false,
wantSize: 3,
},
{
name: "duplicate IDs are deduplicated",
ids: []string{"a", "a", "b"},
wantNil: false,
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildSelectedSet(tt.ids)
if tt.wantNil {
if got != nil {
t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got)
}
return
}
if got == nil {
t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids)
}
if len(got) != tt.wantSize {
t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize)
}
// Verify all unique IDs are present
for _, id := range tt.ids {
if _, ok := got[id]; !ok {
t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id)
}
}
})
}
}
func TestShouldCreateAccount(t *testing.T) {
tests := []struct {
name string
crsID string
selectedSet map[string]struct{}
want bool
}{
{
name: "nil set allows all (backward compatible)",
crsID: "any-id",
selectedSet: nil,
want: true,
},
{
name: "empty set blocks all",
crsID: "any-id",
selectedSet: map[string]struct{}{},
want: false,
},
{
name: "ID in set is allowed",
crsID: "abc-123",
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
want: true,
},
{
name: "ID not in set is blocked",
crsID: "xyz-789",
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := shouldCreateAccount(tt.crsID, tt.selectedSet)
if got != tt.want {
t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v",
tt.crsID, tt.selectedSet, got, tt.want)
}
})
}
}

View File

@@ -45,11 +45,10 @@ func NewCRSSyncService(
}
type SyncFromCRSInput struct {
BaseURL string
Username string
Password string
SyncProxies bool
SelectedAccountIDs []string // if non-empty, only create new accounts with these CRS IDs
BaseURL string
Username string
Password string
SyncProxies bool
}
type SyncFromCRSItemResult struct {
@@ -191,27 +190,25 @@ type crsGeminiAPIKeyAccount struct {
Extra map[string]any `json:"extra"`
}
// fetchCRSExport validates the connection parameters, authenticates with CRS,
// and returns the exported accounts. Shared by SyncFromCRS and PreviewFromCRS.
func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, password string) (*crsExportResponse, error) {
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
if s.cfg == nil {
return nil, errors.New("config is not available")
}
normalizedURL := strings.TrimSpace(baseURL)
baseURL := strings.TrimSpace(input.BaseURL)
if s.cfg.Security.URLAllowlist.Enabled {
normalized, err := normalizeBaseURL(normalizedURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
normalizedURL = normalized
baseURL = normalized
} else {
normalized, err := urlvalidator.ValidateURLFormat(normalizedURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
normalizedURL = normalized
baseURL = normalized
}
if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" {
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
return nil, errors.New("username and password are required")
}
@@ -224,16 +221,12 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username,
client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
return crsExportAccounts(ctx, client, normalizedURL, adminToken)
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
if err != nil {
return nil, err
}
@@ -248,8 +241,6 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
),
}
selectedSet := buildSelectedSet(input.SelectedAccountIDs)
var proxies []Proxy
if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx)
@@ -338,13 +329,6 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
@@ -462,13 +446,6 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
@@ -592,13 +569,6 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
@@ -720,13 +690,6 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
@@ -835,13 +798,6 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
@@ -953,13 +909,6 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
@@ -1304,102 +1253,3 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
return newCredentials
}
// buildSelectedSet converts a slice of selected CRS account IDs to a set for O(1) lookup.
// Returns nil if ids is nil (field not sent → backward compatible: create all).
// Returns an empty map if ids is non-nil but empty (user selected none → create none).
func buildSelectedSet(ids []string) map[string]struct{} {
if ids == nil {
return nil
}
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
// shouldCreateAccount checks if a new CRS account should be created based on user selection.
// Returns true if selectedSet is nil (backward compatible: create all) or if crsID is in the set.
func shouldCreateAccount(crsID string, selectedSet map[string]struct{}) bool {
if selectedSet == nil {
return true
}
_, ok := selectedSet[crsID]
return ok
}
// PreviewFromCRSResult contains the preview of accounts from CRS before sync.
type PreviewFromCRSResult struct {
NewAccounts []CRSPreviewAccount `json:"new_accounts"`
ExistingAccounts []CRSPreviewAccount `json:"existing_accounts"`
}
// CRSPreviewAccount represents a single account in the preview result.
type CRSPreviewAccount struct {
CRSAccountID string `json:"crs_account_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
}
// PreviewFromCRS connects to CRS, fetches all accounts, and classifies them
// as new or existing by batch-querying local crs_account_id mappings.
func (s *CRSSyncService) PreviewFromCRS(ctx context.Context, input SyncFromCRSInput) (*PreviewFromCRSResult, error) {
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
// Batch query all existing CRS account IDs
existingCRSIDs, err := s.accountRepo.ListCRSAccountIDs(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list existing CRS accounts: %w", err)
}
result := &PreviewFromCRSResult{
NewAccounts: make([]CRSPreviewAccount, 0),
ExistingAccounts: make([]CRSPreviewAccount, 0),
}
classify := func(crsID, kind, name, platform, accountType string) {
preview := CRSPreviewAccount{
CRSAccountID: crsID,
Kind: kind,
Name: defaultName(name, crsID),
Platform: platform,
Type: accountType,
}
if _, exists := existingCRSIDs[crsID]; exists {
result.ExistingAccounts = append(result.ExistingAccounts, preview)
} else {
result.NewAccounts = append(result.NewAccounts, preview)
}
}
for _, src := range exported.Data.ClaudeAccounts {
authType := strings.TrimSpace(src.AuthType)
if authType == "" {
authType = AccountTypeOAuth
}
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, authType)
}
for _, src := range exported.Data.ClaudeConsoleAccounts {
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, AccountTypeAPIKey)
}
for _, src := range exported.Data.OpenAIOAuthAccounts {
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeOAuth)
}
for _, src := range exported.Data.OpenAIResponsesAccounts {
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeAPIKey)
}
for _, src := range exported.Data.GeminiOAuthAccounts {
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeOAuth)
}
for _, src := range exported.Data.GeminiAPIKeyAccounts {
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeAPIKey)
}
return result, nil
}

View File

@@ -61,11 +61,6 @@ func applyErrorPassthroughRule(
errMsg = *rule.CustomMessage
}
// 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。
if rule.SkipMonitoring {
c.Set(OpsSkipPassthroughKey, true)
}
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
errType = "upstream_error"
return status, errType, errMsg, true

View File

@@ -194,63 +194,6 @@ func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
assert.Equal(t, "Gemini上游失败", errField["message"])
}
func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
rule.SkipMonitoring = true
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
BindErrorPassthroughService(c, ruleSvc)
_, _, _, matched := applyErrorPassthroughRule(
c,
PlatformAnthropic,
http.StatusBadRequest,
[]byte(`{"error":{"message":"prompt is too long"}}`),
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
)
assert.True(t, matched)
v, exists := c.Get(OpsSkipPassthroughKey)
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
boolVal, ok := v.(bool)
assert.True(t, ok, "value should be bool")
assert.True(t, boolVal)
}
func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
rule.SkipMonitoring = false
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
BindErrorPassthroughService(c, ruleSvc)
_, _, _, matched := applyErrorPassthroughRule(
c,
PlatformAnthropic,
http.StatusBadRequest,
[]byte(`{"error":{"message":"prompt is too long"}}`),
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
)
assert.True(t, matched)
_, exists := c.Get(OpsSkipPassthroughKey)
assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false")
}
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
return &model.ErrorPassthroughRule{
ID: 1,

View File

@@ -45,20 +45,10 @@ type ErrorPassthroughService struct {
cache ErrorPassthroughCache
// 本地内存缓存,用于快速匹配
localCache []*cachedPassthroughRule
localCache []*model.ErrorPassthroughRule
localCacheMu sync.RWMutex
}
// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower
type cachedPassthroughRule struct {
*model.ErrorPassthroughRule
lowerKeywords []string // 预计算的小写关键词
lowerPlatforms []string // 预计算的小写平台
errorCodeSet map[int]struct{} // 预计算的 error code set
}
const maxBodyMatchLen = 8 << 10 // 8KB错误信息不会在 8KB 之后才出现
// NewErrorPassthroughService 创建错误透传规则服务
func NewErrorPassthroughService(
repo ErrorPassthroughRepository,
@@ -160,19 +150,17 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
return nil
}
lowerPlatform := strings.ToLower(platform)
var bodyLower string // 延迟初始化,只在需要关键词匹配时计算
var bodyLowerDone bool
bodyStr := strings.ToLower(string(body))
for _, rule := range rules {
if !rule.Enabled {
continue
}
if !s.platformMatchesCached(rule, lowerPlatform) {
if !s.platformMatches(rule, platform) {
continue
}
if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) {
return rule.ErrorPassthroughRule
if s.ruleMatches(rule, statusCode, bodyStr) {
return rule
}
}
@@ -180,7 +168,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
}
// getCachedRules 获取缓存的规则列表(按优先级排序)
func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
s.localCacheMu.RLock()
rules := s.localCache
s.localCacheMu.RUnlock()
@@ -235,39 +223,17 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
return nil
}
// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算
// setLocalCache 设置本地缓存
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
cached := make([]*cachedPassthroughRule, len(rules))
for i, r := range rules {
cr := &cachedPassthroughRule{ErrorPassthroughRule: r}
if len(r.Keywords) > 0 {
cr.lowerKeywords = make([]string, len(r.Keywords))
for j, kw := range r.Keywords {
cr.lowerKeywords[j] = strings.ToLower(kw)
}
}
if len(r.Platforms) > 0 {
cr.lowerPlatforms = make([]string, len(r.Platforms))
for j, p := range r.Platforms {
cr.lowerPlatforms[j] = strings.ToLower(p)
}
}
if len(r.ErrorCodes) > 0 {
cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes))
for _, code := range r.ErrorCodes {
cr.errorCodeSet[code] = struct{}{}
}
}
cached[i] = cr
}
// 按优先级排序
sort.Slice(cached, func(i, j int) bool {
return cached[i].Priority < cached[j].Priority
sorted := make([]*model.ErrorPassthroughRule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
s.localCacheMu.Lock()
s.localCache = cached
s.localCache = sorted
s.localCacheMu.Unlock()
}
@@ -307,79 +273,62 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
}
}
// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB
func ensureBodyLower(body []byte, bodyLower *string, done *bool) string {
if *done {
return *bodyLower
}
b := body
if len(b) > maxBodyMatchLen {
b = b[:maxBodyMatchLen]
}
*bodyLower = strings.ToLower(string(b))
*done = true
return *bodyLower
}
// platformMatchesCached 使用预计算的小写平台检查是否匹配
func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool {
if len(rule.lowerPlatforms) == 0 {
// platformMatches 检查平台是否匹配
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
// 如果没有配置平台限制,则匹配所有平台
if len(rule.Platforms) == 0 {
return true
}
for _, p := range rule.lowerPlatforms {
if p == lowerPlatform {
platform = strings.ToLower(platform)
for _, p := range rule.Platforms {
if strings.ToLower(p) == platform {
return true
}
}
return false
}
// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换
func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool {
hasErrorCodes := len(rule.errorCodeSet) > 0
hasKeywords := len(rule.lowerKeywords) > 0
// ruleMatches 检查规则是否匹配
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
hasErrorCodes := len(rule.ErrorCodes) > 0
hasKeywords := len(rule.Keywords) > 0
// 如果没有配置任何条件,不匹配
if !hasErrorCodes && !hasKeywords {
return false
}
codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode)
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
if rule.MatchMode == model.MatchModeAll {
// "all" 模式:所有配置的条件都必须满足,短路
if hasErrorCodes && !codeMatch {
return false
}
if hasKeywords {
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
}
return codeMatch
// "all" 模式:所有配置的条件都必须满足
return codeMatch && keywordMatch
}
// "any" 模式:任一条件满足即可,短路
// "any" 模式:任一条件满足即可
if hasErrorCodes && hasKeywords {
if codeMatch {
return true
}
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
return codeMatch || keywordMatch
}
// 只配置了一种条件
if hasKeywords {
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
}
return codeMatch
return codeMatch && keywordMatch
}
// containsIntSet 使用 map 查找替代线性扫描
func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool {
_, ok := set[val]
return ok
}
// containsAnyKeywordCached 使用预计算的小写关键词检查匹配
func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool {
for _, kw := range lowerKeywords {
if strings.Contains(bodyLower, kw) {
// containsInt 检查切片是否包含指定整数
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
for _, v := range slice {
if v == val {
return true
}
}
return false
}
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
for _, kw := range keywords {
if strings.Contains(bodyLower, strings.ToLower(kw)) {
return true
}
}

View File

@@ -145,58 +145,32 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic
return svc
}
// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule测试用
func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule {
cr := &cachedPassthroughRule{ErrorPassthroughRule: rule}
if len(rule.Keywords) > 0 {
cr.lowerKeywords = make([]string, len(rule.Keywords))
for j, kw := range rule.Keywords {
cr.lowerKeywords[j] = strings.ToLower(kw)
}
}
if len(rule.Platforms) > 0 {
cr.lowerPlatforms = make([]string, len(rule.Platforms))
for j, p := range rule.Platforms {
cr.lowerPlatforms[j] = strings.ToLower(p)
}
}
if len(rule.ErrorCodes) > 0 {
cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes))
for _, code := range rule.ErrorCodes {
cr.errorCodeSet[code] = struct{}{}
}
}
return cr
}
// =============================================================================
// 测试 ruleMatchesOptimized 核心匹配逻辑
// 测试 ruleMatches 核心匹配逻辑
// =============================================================================
func TestRuleMatches_NoConditions(t *testing.T) {
// 没有配置任何条件时,不应该匹配
svc := newTestService(nil)
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{},
MatchMode: model.MatchModeAny,
})
}
var bodyLower string
var bodyLowerDone bool
assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone),
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
"没有配置条件时不应该匹配")
}
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{},
MatchMode: model.MatchModeAny,
})
}
tests := []struct {
name string
@@ -212,9 +186,7 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result)
})
}
@@ -222,12 +194,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{"context limit", "model not supported"},
MatchMode: model.MatchModeAny,
})
}
tests := []struct {
name string
@@ -238,14 +210,16 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
{"关键词匹配 context limit", 500, "error: context limit reached", true},
{"关键词匹配 model not supported", 400, "the model not supported here", true},
{"关键词不匹配", 422, "some other error", false},
{"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true},
// 注意ruleMatches 接收的 body 参数应该是已经转换为小写的
// 实际使用时MatchRule 会先将 body 转换为小写再传给 ruleMatches
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
// 模拟 MatchRule 的行为:先转换为小写
bodyLower := strings.ToLower(tt.body)
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
assert.Equal(t, tt.expected, result)
})
}
@@ -254,12 +228,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
// any 模式:错误码 OR 关键词
svc := newTestService(nil)
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAny,
})
}
tests := []struct {
name string
@@ -300,9 +274,7 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
@@ -311,12 +283,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
// all 模式:错误码 AND 关键词
svc := newTestService(nil)
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll,
})
}
tests := []struct {
name string
@@ -357,16 +329,14 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
// =============================================================================
// 测试 platformMatchesCached 平台匹配逻辑
// 测试 platformMatches 平台匹配逻辑
// =============================================================================
func TestPlatformMatches(t *testing.T) {
@@ -424,10 +394,10 @@ func TestPlatformMatches(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
rule := &model.ErrorPassthroughRule{
Platforms: tt.rulePlatforms,
})
result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform))
}
result := svc.platformMatches(rule, tt.requestPlatform)
assert.Equal(t, tt.expected, result)
})
}

View File

@@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
customCodes: []any{float64(500)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 500,
expectStatusCode: 429,
},
{
name: "500_in_custom_codes_matched",
@@ -364,109 +364,3 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
}
// ---------------------------------------------------------------------------
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
// ---------------------------------------------------------------------------
type epTrackingRepo struct {
mockAccountRepoForGemini
rateLimitedCalls int
rateLimitedID int64
setErrCalls int
setErrID int64
tempCalls int
}
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
r.rateLimitedCalls++
r.rateLimitedID = id
return nil
}
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
r.setErrCalls++
r.setErrID = id
return nil
}
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
// ---------------------------------------------------------------------------
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
//
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
// 当上游返回 429/500/503/401 时:
// - 返回给客户端的状态码必须是 500而不是透传原始状态码
// - 不调用 SetRateLimited不进入限流状态
// - 不调用 SetError不停止调度
// - 不调用 handleError
// ---------------------------------------------------------------------------
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
errorCodes := []int{429, 500, 503, 401, 403}
for _, upstreamStatus := range errorCodes {
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{
statusCode: upstreamStatus,
body: `{"error":"some upstream error"}`,
}
repo := &epTrackingRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := &Account{
ID: 500,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(599)},
},
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
// 不应返回 errorSkipped 不触发账号切换)
require.NoError(t, err, "should not return error")
require.NotNil(t, result, "result should not be nil")
require.NotNil(t, result.resp, "response should not be nil")
defer func() { _ = result.resp.Body.Close() }()
// 状态码必须是 500不透传原始状态码
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
"skipped error should return 500, not %d", upstreamStatus)
// 不调用 handleError
require.Equal(t, 0, handleErrorCount,
"handleError should NOT be called for skipped errors")
// 不标记限流
require.Equal(t, 0, repo.rateLimitedCalls,
"SetRateLimited should NOT be called for skipped errors")
// 不停止调度
require.Equal(t, 0, repo.setErrCalls,
"SetError should NOT be called for skipped errors")
// 只调用一次上游(不重试)
require.Equal(t, 1, upstream.calls,
"should call upstream exactly once (no retry)")
})
}
}

View File

@@ -158,7 +158,6 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode int
body []byte
expectedHandled bool
expectedStatus int // expected outStatus
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
handleErrorCalls int
}{
@@ -172,7 +171,6 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: false,
expectedStatus: 500, // passthrough
handleErrorCalls: 0,
},
{
@@ -189,7 +187,6 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500, // not in custom codes
body: []byte(`"error"`),
expectedHandled: true,
expectedStatus: http.StatusInternalServerError, // skipped → 500
handleErrorCalls: 0,
},
{
@@ -206,7 +203,6 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: true,
expectedStatus: 500, // matched → original status
handleErrorCalls: 1,
},
{
@@ -229,7 +225,6 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 503,
body: []byte(`overloaded`),
expectedHandled: true,
expectedStatus: 503, // temp_unscheduled → original status
expectedSwitchErr: true,
handleErrorCalls: 0,
},
@@ -255,10 +250,9 @@ func TestApplyErrorPolicy(t *testing.T) {
isStickySession: true,
}
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
if tt.expectedSwitchErr {

View File

@@ -21,72 +21,3 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}
func TestStripBetaToken(t *testing.T) {
tests := []struct {
name string
header string
token string
want string
}{
{
name: "token in middle",
header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "token at start",
header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "token at end",
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "token not present",
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "empty header",
header: "",
token: "context-1m-2025-08-07",
want: "",
},
{
name: "with spaces",
header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14",
token: "context-1m-2025-08-07",
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "only token",
header: "context-1m-2025-08-07",
token: "context-1m-2025-08-07",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := stripBetaToken(tt.header, tt.token)
require.Equal(t, tt.want, got)
})
}
}
func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20"
drop := map[string]struct{}{"context-1m-2025-08-07": {}}
got := mergeAnthropicBetaDropping(required, incoming, drop)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
require.NotContains(t, got, "context-1m-2025-08-07")
}

View File

@@ -77,9 +77,6 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error {
return nil
}
@@ -87,7 +84,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {

View File

@@ -101,9 +101,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
}
}
// thinking: {type: "enabled" | "adaptive"}
// thinking: {type: "enabled"}
if rawThinking, ok := req["thinking"].(map[string]any); ok {
if t, ok := rawThinking["type"].(string); ok && (t == "enabled" || t == "adaptive") {
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
parsed.ThinkingEnabled = true
}
}
@@ -161,9 +161,9 @@ func parseIntegralNumber(raw any) (int, bool) {
// Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures
//
// 策略:
// - thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块
// - thinking.type "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块(避免 400
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
func FilterThinkingBlocks(body []byte) []byte {
return filterThinkingBlocksInternal(body, false)
@@ -489,9 +489,9 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
}
// filterThinkingBlocksInternal removes invalid thinking blocks from request
// 策略:
// - thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块
// - thinking.type "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// Fast path: if body doesn't contain "thinking", skip parsing
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
@@ -511,7 +511,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// Check if thinking is enabled
thinkingEnabled := false
if thinking, ok := req["thinking"].(map[string]any); ok {
if thinkType, ok := thinking["type"].(string); ok && (thinkType == "enabled" || thinkType == "adaptive") {
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
thinkingEnabled = true
}
}

View File

@@ -29,14 +29,6 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
require.True(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_ThinkingAdaptiveEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
}
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
parsed, err := ParseGatewayRequest(body, "")
@@ -217,16 +209,6 @@ func TestFilterThinkingBlocks(t *testing.T) {
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
shouldFilter: true,
},
{
name: "does not filter signed thinking blocks when thinking adaptive",
input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"ok","signature":"sig_real_123"},{"type":"text","text":"B"}]}]}`,
shouldFilter: false,
},
{
name: "filters unsigned thinking blocks when thinking adaptive",
input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"internal","signature":""},{"type":"text","text":"B"}]}]}`,
shouldFilter: true,
},
{
name: "handles no thinking blocks",
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,

View File

@@ -243,12 +243,6 @@ var (
}
)
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
@@ -349,8 +343,6 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
CacheCreation5mTokens int // 5分钟缓存创建token来自嵌套 cache_creation 对象)
CacheCreation1hTokens int // 1小时缓存创建token来自嵌套 cache_creation 对象)
}
// ForwardResult 转发结果
@@ -370,31 +362,15 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应应在同一账号上重试 N 次再切换
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
}
func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
}
// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。
// 由 handler 层在同账号重试全部用尽、切换账号时调用。
func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) {
if failoverErr == nil || !failoverErr.RetryableOnSameAccount {
return
}
// 根据状态码选择封禁策略
switch failoverErr.StatusCode {
case http.StatusBadRequest:
tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]")
case http.StatusBadGateway:
tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]")
}
}
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
@@ -1707,17 +1683,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool {
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true)
if err != nil {
return false
}
return len(accounts) == 1
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
@@ -2708,60 +2673,6 @@ func hasClaudeCodePrefix(text string) bool {
return false
}
// matchesFilterPrefix 检查文本是否匹配任一过滤前缀
func matchesFilterPrefix(text string) bool {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return true
}
}
return false
}
// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素
// 直接从 body 解析 system不依赖外部传入的 parsed.System因为前置步骤可能已修改 body 中的 system
func filterSystemBlocksByPrefix(body []byte) []byte {
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return body
}
switch {
case sys.Type == gjson.String:
if matchesFilterPrefix(sys.Str) {
result, err := sjson.DeleteBytes(body, "system")
if err != nil {
return body
}
return result
}
case sys.IsArray():
var parsed []any
if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil {
return body
}
filtered := make([]any, 0, len(parsed))
changed := false
for _, item := range parsed {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) {
changed = true
continue
}
}
filtered = append(filtered, item)
}
if changed {
result, err := sjson.SetBytes(body, "system", filtered)
if err != nil {
return body
}
return result
}
}
return body
}
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte {
@@ -3041,12 +2952,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据)
// 放在 inject/normalize 之后,确保不会被覆盖
if account.IsOAuth() {
body = filterSystemBlocksByPrefix(body)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
@@ -3553,12 +3458,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}}
drop := map[string]struct{}{claude.BetaClaudeCode: {}}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else {
// Claude Code 客户端:尽量透传原始 header仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta")
req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M))
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader))
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
@@ -3633,8 +3538,7 @@ func requestNeedsBetaFeatures(body []byte) bool {
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
thinkingType := gjson.GetBytes(body, "thinking.type").String()
if strings.EqualFold(thinkingType, "enabled") || strings.EqualFold(thinkingType, "adaptive") {
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
@@ -3712,23 +3616,6 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
return strings.Join(out, ",")
}
// stripBetaToken removes a single beta token from a comma-separated header value.
// It short-circuits when the token is not present to avoid unnecessary allocations.
func stripBetaToken(header, token string) string {
if !strings.Contains(header, token) {
return header
}
out := make([]string, 0, 8)
for _, p := range strings.Split(header, ",") {
p = strings.TrimSpace(p)
if p == "" || p == token {
continue
}
out = append(out, p)
}
return strings.Join(out, ",")
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
@@ -4293,23 +4180,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget()
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
rewriteCacheCreationJSON(u, overrideTarget)
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
rewriteCacheCreationJSON(u, overrideTarget)
}
}
}
if needModelReplace {
if msg, ok := event["message"].(map[string]any); ok {
if model, ok := msg["model"].(string); ok && model == mappedModel {
@@ -4437,14 +4307,6 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
usage.InputTokens = msgStart.Message.Usage.InputTokens
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() || cc1h.Exists() {
usage.CacheCreation5mTokens = int(cc5m.Int())
usage.CacheCreation1hTokens = int(cc1h.Int())
}
}
// 解析message_delta获取tokens兼容GLM等把所有usage放在delta中的API
@@ -4473,66 +4335,6 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
if msgDelta.Usage.CacheReadInputTokens > 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() || cc1h.Exists() {
usage.CacheCreation5mTokens = int(cc5m.Int())
usage.CacheCreation1hTokens = int(cc1h.Int())
}
}
}
// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。
// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。
func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool {
// Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别
if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 {
usage.CacheCreation5mTokens = usage.CacheCreationInputTokens
}
total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
if total == 0 {
return false
}
switch target {
case "1h":
if usage.CacheCreation1hTokens == total {
return false // 已经全是 1h
}
usage.CacheCreation1hTokens = total
usage.CacheCreation5mTokens = 0
default: // "5m"
if usage.CacheCreation5mTokens == total {
return false // 已经全是 5m
}
usage.CacheCreation5mTokens = total
usage.CacheCreation1hTokens = 0
}
return true
}
// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。
// usageObj 是 usage JSON 对象map[string]any
func rewriteCacheCreationJSON(usageObj map[string]any, target string) {
ccObj, ok := usageObj["cache_creation"].(map[string]any)
if !ok {
return
}
v5m, _ := ccObj["ephemeral_5m_input_tokens"].(float64)
v1h, _ := ccObj["ephemeral_1h_input_tokens"].(float64)
total := v5m + v1h
if total == 0 {
return
}
switch target {
case "1h":
ccObj["ephemeral_1h_input_tokens"] = total
ccObj["ephemeral_5m_input_tokens"] = float64(0)
default: // "5m"
ccObj["ephemeral_5m_input_tokens"] = total
ccObj["ephemeral_1h_input_tokens"] = float64(0)
}
}
@@ -4553,14 +4355,6 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
return nil, fmt.Errorf("parse response: %w", err)
}
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
if cc5m.Exists() || cc1h.Exists() {
response.Usage.CacheCreation5mTokens = int(cc5m.Int())
response.Usage.CacheCreation1hTokens = int(cc1h.Int())
}
// 兼容 Kimi cached_tokens → cache_read_input_tokens
if response.Usage.CacheReadInputTokens == 0 {
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
@@ -4572,20 +4366,6 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget()
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
body = newBody
}
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil {
body = newBody
}
}
}
// 如果有模型映射替换响应中的model字段
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
@@ -4662,13 +4442,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
result.Usage.InputTokens = 0
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
@@ -4699,12 +4472,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else {
// Token 计费
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
}
var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
@@ -4738,8 +4509,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
@@ -4754,7 +4523,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: time.Now(),
}
@@ -4855,13 +4623,6 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
result.Usage.InputTokens = 0
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
@@ -4892,12 +4653,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} else {
// Token 计费(使用长上下文计费方法)
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
}
var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
@@ -4931,8 +4690,6 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
@@ -4947,7 +4704,6 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: time.Now(),
}
@@ -5253,8 +5009,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
drop := map[string]struct{}{claude.BetaContext1M: {}}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta))
} else {
clientBetaHeader := req.Header.Get("anthropic-beta")
if clientBetaHeader == "" {
@@ -5264,7 +5019,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M))
req.Header.Set("anthropic-beta", beta)
}
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {

View File

@@ -770,14 +770,6 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
break
}
// 错误策略优先:匹配则跳过重试直接处理。
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
@@ -847,7 +839,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
upstreamReqID := resp.Header.Get(requestIDHeader)
@@ -880,37 +872,6 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
if resp.StatusCode == http.StatusBadRequest {
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if isGoogleProjectConfigError(msg400) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
@@ -1215,14 +1176,6 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
}
// 错误策略优先:匹配则跳过重试直接处理。
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
@@ -1330,7 +1283,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if contentType == "" {
contentType = "application/json"
}
c.Data(http.StatusInternalServerError, contentType, respBody)
c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
@@ -1361,34 +1314,6 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
if resp.StatusCode == http.StatusBadRequest {
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if isGoogleProjectConfigError(msg400) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody)))
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true}
}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
@@ -1500,26 +1425,6 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
// checkErrorPolicyInLoop 在重试循环内预检查错误策略。
// 返回 true 表示策略已匹配(调用者应 breakresp 已重建可直接使用。
// 返回 false 表示 ErrorPolicyNoneresp 已重建,调用者继续走重试逻辑。
func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop(
ctx context.Context, account *Account, resp *http.Response,
) (matched bool, rebuilt *http.Response) {
if resp.StatusCode < 400 || s.rateLimitService == nil {
return false, resp
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
rebuilt = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(body)),
}
policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body)
return policy != ErrorPolicyNone, rebuilt
}
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
@@ -2663,12 +2568,11 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
prompt, _ := asInt(usageMeta["promptTokenCount"])
cand, _ := asInt(usageMeta["candidatesTokenCount"])
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
thoughts, _ := asInt(usageMeta["thoughtsTokenCount"])
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
return &ClaudeUsage{
InputTokens: prompt - cached,
OutputTokens: cand + thoughts,
OutputTokens: cand,
CacheReadInputTokens: cached,
}
}
@@ -2693,10 +2597,6 @@ func asInt(v any) (int, bool) {
}
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
// 遵守自定义错误码策略:未命中则跳过所有限流处理
if !account.ShouldHandleErrorCode(statusCode) {
return
}
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return

View File

@@ -4,8 +4,6 @@ import (
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
@@ -205,70 +203,3 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}
func TestExtractGeminiUsage_ThoughtsTokenCount(t *testing.T) {
tests := []struct {
name string
resp map[string]any
wantInput int
wantOutput int
wantCacheRead int
wantNil bool
}{
{
name: "with thoughtsTokenCount",
resp: map[string]any{
"usageMetadata": map[string]any{
"promptTokenCount": float64(100),
"candidatesTokenCount": float64(20),
"thoughtsTokenCount": float64(50),
},
},
wantInput: 100,
wantOutput: 70,
},
{
name: "with thoughtsTokenCount and cache",
resp: map[string]any{
"usageMetadata": map[string]any{
"promptTokenCount": float64(100),
"candidatesTokenCount": float64(20),
"cachedContentTokenCount": float64(30),
"thoughtsTokenCount": float64(50),
},
},
wantInput: 70,
wantOutput: 70,
wantCacheRead: 30,
},
{
name: "without thoughtsTokenCount (old model)",
resp: map[string]any{
"usageMetadata": map[string]any{
"promptTokenCount": float64(100),
"candidatesTokenCount": float64(20),
},
},
wantInput: 100,
wantOutput: 20,
},
{
name: "no usageMetadata",
resp: map[string]any{},
wantNil: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
usage := extractGeminiUsage(tt.resp)
if tt.wantNil {
require.Nil(t, usage)
return
}
require.NotNil(t, usage)
require.Equal(t, tt.wantInput, usage.InputTokens)
require.Equal(t, tt.wantOutput, usage.OutputTokens)
require.Equal(t, tt.wantCacheRead, usage.CacheReadInputTokens)
})
}
}

View File

@@ -66,15 +66,12 @@ func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account)
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {

View File

@@ -112,19 +112,13 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran
result.Modified = true
}
// Strip parameters unsupported by codex models via the Responses API.
for _, key := range []string{
"max_output_tokens",
"max_completion_tokens",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
} {
if _, ok := reqBody[key]; ok {
delete(reqBody, key)
result.Modified = true
}
if _, ok := reqBody["max_output_tokens"]; ok {
delete(reqBody, "max_output_tokens")
result.Modified = true
}
if _, ok := reqBody["max_completion_tokens"]; ok {
delete(reqBody, "max_completion_tokens")
result.Modified = true
}
if normalizeCodexTools(reqBody) {

View File

@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page,
PageSize: opsAccountsPageSize,
}, platformFilter, "", "", "", 0)
}, platformFilter, "", "", "")
if err != nil {
return nil, err
}

View File

@@ -20,10 +20,6 @@ const (
// retry the specific upstream attempt (not just the client request).
// This value is sanitized+trimmed before being persisted.
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
// ops_error_logger 中间件检查此 key为 true 时跳过错误记录。
OpsSkipPassthroughKey = "ops_skip_passthrough"
)
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
@@ -107,37 +103,6 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
evCopy := ev
existing = append(existing, &evCopy)
c.Set(OpsUpstreamErrorsKey, existing)
checkSkipMonitoringForUpstreamEvent(c, &evCopy)
}
// checkSkipMonitoringForUpstreamEvent checks whether the upstream error event
// matches a passthrough rule with skip_monitoring=true and, if so, sets the
// OpsSkipPassthroughKey on the context. This ensures intermediate retry /
// failover errors (which never go through the final applyErrorPassthroughRule
// path) can still suppress ops_error_logs recording.
func checkSkipMonitoringForUpstreamEvent(c *gin.Context, ev *OpsUpstreamErrorEvent) {
if ev.UpstreamStatusCode == 0 {
return
}
svc := getBoundErrorPassthroughService(c)
if svc == nil {
return
}
// Use the best available body representation for keyword matching.
// Even when body is empty, MatchRule can still match rules that only
// specify ErrorCodes (no Keywords), so we always call it.
body := ev.Detail
if body == "" {
body = ev.Message
}
rule := svc.MatchRule(ev.Platform, ev.UpstreamStatusCode, []byte(body))
if rule != nil && rule.SkipMonitoring {
c.Set(OpsSkipPassthroughKey, true)
}
}
func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {

View File

@@ -27,15 +27,14 @@ var (
// LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct {
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
}
// PricingRemoteClient 远程价格数据获取接口
@@ -46,15 +45,14 @@ type PricingRemoteClient interface {
// LiteLLMRawEntry 用于解析原始JSON数据
type LiteLLMRawEntry struct {
InputCostPerToken *float64 `json:"input_cost_per_token"`
OutputCostPerToken *float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"`
InputCostPerToken *float64 `json:"input_cost_per_token"`
OutputCostPerToken *float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"`
}
// PricingService 动态价格服务
@@ -320,9 +318,6 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.CacheCreationInputTokenCost != nil {
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
}
if entry.CacheCreationInputTokenCostAbove1hr != nil {
pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr
}
if entry.CacheReadInputTokenCost != nil {
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
}

View File

@@ -381,31 +381,10 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
}
// 2. Anthropic 平台:尝试解析 per-window 头5h / 7d选择实际触发的窗口
if result := calculateAnthropic429ResetTime(headers); result != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
// 更新 session window优先使用 5h-reset 头精确计算,否则从 resetAt 反推
windowEnd := result.resetAt
if result.fiveHourReset != nil {
windowEnd = *result.fiveHourReset
}
windowStart := windowEnd.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
}
slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second))
return
}
// 3. 尝试从响应头解析重置时间Anthropic 聚合头,向后兼容)
// 2. 尝试从响应头解析重置时间Anthropic
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
// 4. 如果响应头没有尝试从响应体解析OpenAI usage_limit_reached, Gemini
// 3. 如果响应头没有尝试从响应体解析OpenAI usage_limit_reached, Gemini
if resetTimestamp == "" {
switch account.Platform {
case PlatformOpenAI:
@@ -518,112 +497,6 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil
}
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
type anthropic429Result struct {
resetAt time.Time // The correct reset time to use for SetRateLimited
fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available
}
// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers
// to determine which window (5h or 7d) actually triggered the 429.
//
// Headers used:
// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold
// - anthropic-ratelimit-unified-5h-reset
// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold
// - anthropic-ratelimit-unified-7d-reset
//
// Returns nil when the per-window headers are absent (caller should fall back to
// the aggregated anthropic-ratelimit-unified-reset header).
func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result {
reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset")
reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset")
if reset5hStr == "" && reset7dStr == "" {
return nil
}
var reset5h, reset7d *time.Time
if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil {
t := time.Unix(ts, 0)
reset5h = &t
}
if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil {
t := time.Unix(ts, 0)
reset7d = &t
}
is5hExceeded := isAnthropicWindowExceeded(headers, "5h")
is7dExceeded := isAnthropicWindowExceeded(headers, "7d")
slog.Info("anthropic_429_window_analysis",
"is_5h_exceeded", is5hExceeded,
"is_7d_exceeded", is7dExceeded,
"reset_5h", reset5hStr,
"reset_7d", reset7dStr,
)
// Select the correct reset time based on which window(s) are exceeded.
var chosen *time.Time
switch {
case is5hExceeded && is7dExceeded:
// Both exceeded → prefer 7d (longer cooldown), fall back to 5h
chosen = reset7d
if chosen == nil {
chosen = reset5h
}
case is5hExceeded:
chosen = reset5h
case is7dExceeded:
chosen = reset7d
default:
// Neither flag clearly exceeded — pick the sooner reset as best guess
chosen = pickSooner(reset5h, reset7d)
}
if chosen == nil {
return nil
}
return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h}
}
// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window
// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers.
func isAnthropicWindowExceeded(headers http.Header, window string) bool {
prefix := "anthropic-ratelimit-unified-" + window + "-"
// Check surpassed-threshold first (most explicit signal)
if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") {
return true
}
// Fall back to utilization >= 1.0
if utilStr := headers.Get(prefix + "utilization"); utilStr != "" {
if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 {
// Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0
return true
}
}
return false
}
// pickSooner returns whichever of the two time pointers is earlier.
// If only one is non-nil, it is returned. If both are nil, returns nil.
func pickSooner(a, b *time.Time) *time.Time {
switch {
case a != nil && b != nil:
if a.Before(*b) {
return a
}
return b
case a != nil:
return a
default:
return b
}
}
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
@@ -750,10 +623,6 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
// 同时清除模型级别限流
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
}
return nil
}

View File

@@ -1,202 +0,0 @@
package service
import (
"net/http"
"testing"
"time"
)
func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
}
}
func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1771549200)
// fiveHourReset should still be populated for session window calculation
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
}
}
func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1771549200)
}
func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
if result != nil {
t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt)
}
}
func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) {
result := calculateAnthropic429ResetTime(http.Header{})
if result != nil {
t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt)
}
}
func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
}
func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
}
func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
}
func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05")
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1770998400)
}
func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) {
headers := http.Header{}
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03")
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
result := calculateAnthropic429ResetTime(headers)
assertAnthropicResult(t, result, 1771549200)
if result.fiveHourReset != nil {
t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset)
}
}
func TestIsAnthropicWindowExceeded(t *testing.T) {
tests := []struct {
name string
headers http.Header
window string
expected bool
}{
{
name: "utilization above 1.0",
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"),
window: "5h",
expected: true,
},
{
name: "utilization exactly 1.0",
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"),
window: "5h",
expected: true,
},
{
name: "utilization below 1.0",
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"),
window: "5h",
expected: false,
},
{
name: "surpassed-threshold true",
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"),
window: "7d",
expected: true,
},
{
name: "surpassed-threshold True (case insensitive)",
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"),
window: "7d",
expected: true,
},
{
name: "surpassed-threshold false",
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"),
window: "7d",
expected: false,
},
{
name: "no headers",
headers: http.Header{},
window: "5h",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := isAnthropicWindowExceeded(tc.headers, tc.window)
if got != tc.expected {
t.Errorf("expected %v, got %v", tc.expected, got)
}
})
}
}
// assertAnthropicResult is a test helper that verifies the result is non-nil and
// has the expected resetAt unix timestamp.
func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) {
t.Helper()
if result == nil {
t.Fatal("expected non-nil result")
return // unreachable, but satisfies staticcheck SA5011
}
want := time.Unix(wantUnix, 0)
if !result.resetAt.Equal(want) {
t.Errorf("expected resetAt=%v, got %v", want, result.resetAt)
}
}
func makeHeader(key, value string) http.Header {
h := http.Header{}
h.Set(key, value)
return h
}

View File

@@ -26,8 +26,8 @@ type UsageLog struct {
CacheCreationTokens int
CacheReadTokens int
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
CacheCreation5mTokens int
CacheCreation1hTokens int
InputCost float64
OutputCost float64
@@ -46,9 +46,6 @@ type UsageLog struct {
UserAgent *string
IPAddress *string
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
CacheTTLOverridden bool
// 图片生成字段
ImageCount int
ImageSize *string

View File

@@ -1,4 +0,0 @@
-- Add skip_monitoring field to error_passthrough_rules table
-- When true, errors matching this rule will not be recorded in ops_error_logs
ALTER TABLE error_passthrough_rules
ADD COLUMN IF NOT EXISTS skip_monitoring BOOLEAN NOT NULL DEFAULT false;

View File

@@ -1,14 +0,0 @@
-- Drop legacy cache token columns that lack the underscore separator.
-- These were created by GORM's automatic snake_case conversion:
-- CacheCreation5mTokens → cache_creation5m_tokens (incorrect)
-- CacheCreation1hTokens → cache_creation1h_tokens (incorrect)
--
-- The canonical columns are:
-- cache_creation_5m_tokens (defined in 001_init.sql)
-- cache_creation_1h_tokens (defined in 001_init.sql)
--
-- Migration 009 already copied data from legacy → canonical columns.
-- This migration drops the legacy columns to avoid confusion.
ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation5m_tokens;
ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation1h_tokens;

View File

@@ -1,2 +0,0 @@
-- Add cache_ttl_overridden flag to usage_logs for tracking cache TTL override per account.
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS cache_ttl_overridden BOOLEAN NOT NULL DEFAULT FALSE;

View File

@@ -158,7 +158,6 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
- PGDATA=/var/lib/postgresql/data
- TZ=${TZ:-Asia/Shanghai}
networks:
- sub2api-network

View File

@@ -17,7 +17,7 @@
"dependencies": {
"@lobehub/icons": "^4.0.2",
"@vueuse/core": "^10.7.0",
"axios": "^1.13.5",
"axios": "^1.6.2",
"chart.js": "^4.4.1",
"dompurify": "^3.3.1",
"driver.js": "^1.4.0",

View File

@@ -15,8 +15,8 @@ importers:
specifier: ^10.7.0
version: 10.11.1(vue@3.5.26(typescript@5.6.3))
axios:
specifier: ^1.13.5
version: 1.13.5
specifier: ^1.6.2
version: 1.13.2
chart.js:
specifier: ^4.4.1
version: 4.5.1
@@ -1257,67 +1257,56 @@ packages:
resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==}
cpu: [arm]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-arm-musleabihf@4.54.0':
resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==}
cpu: [arm]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-arm64-gnu@4.54.0':
resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==}
cpu: [arm64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-arm64-musl@4.54.0':
resolution: {integrity: sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==}
cpu: [arm64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-loong64-gnu@4.54.0':
resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==}
cpu: [loong64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-ppc64-gnu@4.54.0':
resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==}
cpu: [ppc64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-riscv64-gnu@4.54.0':
resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==}
cpu: [riscv64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-riscv64-musl@4.54.0':
resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==}
cpu: [riscv64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-s390x-gnu@4.54.0':
resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==}
cpu: [s390x]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-x64-gnu@4.54.0':
resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==}
cpu: [x64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-x64-musl@4.54.0':
resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==}
cpu: [x64]
os: [linux]
libc: [musl]
'@rollup/rollup-openharmony-arm64@4.54.0':
resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==}
@@ -1816,8 +1805,8 @@ packages:
peerDependencies:
postcss: ^8.1.0
axios@1.13.5:
resolution: {integrity: sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==}
axios@1.13.2:
resolution: {integrity: sha512-VPk9ebNqPcy5lRGuSlKx752IlDatOjT9paPlm8A7yOuW2Fbvp4X3JznJtT4f0GzGLLiWE9W8onz51SqLYwzGaA==}
babel-plugin-macros@3.1.0:
resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==}
@@ -6398,7 +6387,7 @@ snapshots:
postcss: 8.5.6
postcss-value-parser: 4.2.0
axios@1.13.5:
axios@1.13.2:
dependencies:
follow-redirects: 1.15.11
form-data: 4.0.5

View File

@@ -32,7 +32,6 @@ export async function list(
platform?: string
type?: string
status?: string
group?: string
search?: string
},
options?: {
@@ -328,34 +327,11 @@ export async function getAvailableModels(id: number): Promise<ClaudeModel[]> {
return data
}
export interface CRSPreviewAccount {
crs_account_id: string
kind: string
name: string
platform: string
type: string
}
export interface PreviewFromCRSResult {
new_accounts: CRSPreviewAccount[]
existing_accounts: CRSPreviewAccount[]
}
export async function previewFromCrs(params: {
base_url: string
username: string
password: string
}): Promise<PreviewFromCRSResult> {
const { data } = await apiClient.post<PreviewFromCRSResult>('/admin/accounts/sync/crs/preview', params)
return data
}
export async function syncFromCrs(params: {
base_url: string
username: string
password: string
sync_proxies?: boolean
selected_account_ids?: string[]
}): Promise<{
created: number
updated: number
@@ -369,19 +345,7 @@ export async function syncFromCrs(params: {
error?: string
}>
}> {
const { data } = await apiClient.post<{
created: number
updated: number
skipped: number
failed: number
items: Array<{
crs_account_id: string
kind: string
name: string
action: string
error?: string
}>
}>('/admin/accounts/sync/crs', params)
const { data } = await apiClient.post('/admin/accounts/sync/crs', params)
return data
}
@@ -478,7 +442,6 @@ export const accountsAPI = {
batchCreate,
batchUpdateCredentials,
bulkUpdate,
previewFromCrs,
syncFromCrs,
exportData,
importData,

View File

@@ -53,18 +53,4 @@ export async function exchangeCode(
return data
}
export async function refreshAntigravityToken(
refreshToken: string,
proxyId?: number | null
): Promise<AntigravityTokenInfo> {
const payload: Record<string, any> = { refresh_token: refreshToken }
if (proxyId) payload.proxy_id = proxyId
const { data } = await apiClient.post<AntigravityTokenInfo>(
'/admin/antigravity/oauth/refresh-token',
payload
)
return data
}
export default { generateAuthUrl, exchangeCode, refreshAntigravityToken }
export default { generateAuthUrl, exchangeCode }

View File

@@ -21,7 +21,6 @@ export interface ErrorPassthroughRule {
response_code: number | null
passthrough_body: boolean
custom_message: string | null
skip_monitoring: boolean
description: string | null
created_at: string
updated_at: string
@@ -42,7 +41,6 @@ export interface CreateRuleRequest {
response_code?: number | null
passthrough_body?: boolean
custom_message?: string | null
skip_monitoring?: boolean
description?: string | null
}
@@ -61,7 +59,6 @@ export interface UpdateRuleRequest {
response_code?: number | null
passthrough_body?: boolean
custom_message?: string | null
skip_monitoring?: boolean
description?: string | null
}

View File

@@ -41,7 +41,7 @@
>
<div class="mb-2 flex items-center justify-between">
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.groupCountTotal', { count: groups.length }) }}
{{ t('admin.accounts.allGroups', { count: groups.length }) }}
</span>
<button
@click="showPopover = false"

View File

@@ -708,7 +708,6 @@ const groupIds = ref<number[]>([])
// All models list (combined Anthropic + OpenAI)
const allModels = [
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
@@ -755,13 +754,6 @@ const presetMappings = [
color:
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
},
{
label: 'Sonnet 4.6',
from: 'claude-sonnet-4-6',
to: 'claude-sonnet-4-6',
color:
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
},
{
label: 'Opus->Sonnet',
from: 'claude-opus-4-5-20251101',

View File

@@ -665,8 +665,8 @@
<Icon name="cloud" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">API Key</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.antigravityApikey') }}</span>
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{ t('admin.accounts.types.upstream') }}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.upstreamDesc') }}</span>
</div>
</button>
</div>
@@ -681,7 +681,7 @@
type="text"
required
class="input"
placeholder="https://cloudcode-pa.googleapis.com"
placeholder="https://s.konstants.xyz"
/>
<p class="input-hint">{{ t('admin.accounts.upstream.baseUrlHint') }}</p>
</div>
@@ -816,8 +816,8 @@
</div>
</div>
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
<!-- API Key input (only for apikey type) -->
<div v-if="form.type === 'apikey'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
<input
@@ -862,7 +862,7 @@
<p class="input-hint">{{ t('admin.accounts.gemini.tier.aiStudioHint') }}</p>
</div>
<!-- Model Restriction Section (不适用于 GeminiAntigravity 已在上层条件排除) -->
<!-- Model Restriction Section (不适用于 Gemini) -->
<div v-if="form.platform !== 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
@@ -1527,46 +1527,6 @@
</button>
</div>
</div>
<!-- Cache TTL Override -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.cacheTTLOverride.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.cacheTTLOverride.hint') }}
</p>
</div>
<button
type="button"
@click="cacheTTLOverrideEnabled = !cacheTTLOverrideEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
cacheTTLOverrideEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
cacheTTLOverrideEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="cacheTTLOverrideEnabled" class="mt-3">
<label class="input-label text-xs">{{ t('admin.accounts.quotaControl.cacheTTLOverride.target') }}</label>
<select
v-model="cacheTTLOverrideTarget"
class="mt-1 block w-full rounded-md border border-gray-300 bg-white px-3 py-2 text-sm shadow-sm focus:border-primary-500 focus:outline-none focus:ring-1 focus:ring-primary-500 dark:border-dark-500 dark:bg-dark-700 dark:text-white"
>
<option value="5m">5m</option>
<option value="1h">1h</option>
</select>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.cacheTTLOverride.targetHint') }}
</p>
</div>
</div>
</div>
<div>
@@ -1687,12 +1647,12 @@
:show-proxy-warning="form.platform !== 'openai' && !!form.proxy_id"
:allow-multiple="form.platform === 'anthropic'"
:show-cookie-option="form.platform === 'anthropic'"
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'antigravity'"
:show-refresh-token-option="form.platform === 'openai'"
:platform="form.platform"
:show-project-id="geminiOAuthType === 'code_assist'"
@generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth"
@validate-refresh-token="handleValidateRefreshToken"
@validate-refresh-token="handleOpenAIValidateRT"
/>
</div>
@@ -2186,8 +2146,6 @@ const maxSessions = ref<number | null>(null)
const sessionIdleTimeout = ref<number | null>(null)
const tlsFingerprintEnabled = ref(false)
const sessionIdMaskingEnabled = ref(false)
const cacheTTLOverrideEnabled = ref(false)
const cacheTTLOverrideTarget = ref<string>('5m')
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free')
@@ -2639,8 +2597,6 @@ const resetForm = () => {
sessionIdleTimeout.value = null
tlsFingerprintEnabled.value = false
sessionIdMaskingEnabled.value = false
cacheTTLOverrideEnabled.value = false
cacheTTLOverrideTarget.value = '5m'
antigravityAccountType.value = 'oauth'
upstreamBaseUrl.value = ''
upstreamApiKey.value = ''
@@ -2846,14 +2802,6 @@ const handleGenerateUrl = async () => {
}
}
const handleValidateRefreshToken = (rt: string) => {
if (form.platform === 'openai') {
handleOpenAIValidateRT(rt)
} else if (form.platform === 'antigravity') {
handleAntigravityValidateRT(rt)
}
}
const formatDateTimeLocal = formatDateTimeLocalInput
const parseDateTimeLocal = parseDateTimeLocalInput
@@ -3002,95 +2950,6 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
}
}
// Antigravity 手动 RT 批量验证和创建
const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
if (!refreshTokenInput.trim()) return
// Parse multiple refresh tokens (one per line)
const refreshTokens = refreshTokenInput
.split('\n')
.map((rt) => rt.trim())
.filter((rt) => rt)
if (refreshTokens.length === 0) {
antigravityOAuth.error.value = t('admin.accounts.oauth.antigravity.pleaseEnterRefreshToken')
return
}
antigravityOAuth.loading.value = true
antigravityOAuth.error.value = ''
let successCount = 0
let failedCount = 0
const errors: string[] = []
try {
for (let i = 0; i < refreshTokens.length; i++) {
try {
const tokenInfo = await antigravityOAuth.validateRefreshToken(
refreshTokens[i],
form.proxy_id
)
if (!tokenInfo) {
failedCount++
errors.push(`#${i + 1}: ${antigravityOAuth.error.value || 'Validation failed'}`)
antigravityOAuth.error.value = ''
continue
}
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
// Generate account name with index for batch
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
// Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials
await adminAPI.accounts.create({
name: accountName,
notes: form.notes,
platform: 'antigravity',
type: 'oauth',
credentials,
extra: {},
proxy_id: form.proxy_id,
concurrency: form.concurrency,
priority: form.priority,
rate_multiplier: form.rate_multiplier,
group_ids: form.group_ids,
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
})
successCount++
} catch (error: any) {
failedCount++
const errMsg = error.response?.data?.detail || error.message || 'Unknown error'
errors.push(`#${i + 1}: ${errMsg}`)
}
}
// Show results
if (successCount > 0 && failedCount === 0) {
appStore.showSuccess(
refreshTokens.length > 1
? t('admin.accounts.oauth.batchSuccess', { count: successCount })
: t('admin.accounts.accountCreated')
)
emit('created')
handleClose()
} else if (successCount > 0 && failedCount > 0) {
appStore.showWarning(
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
)
antigravityOAuth.error.value = errors.join('\n')
emit('created')
} else {
antigravityOAuth.error.value = errors.join('\n')
appStore.showError(t('admin.accounts.oauth.batchFailed'))
}
} finally {
antigravityOAuth.loading.value = false
}
}
// Gemini OAuth 授权码兑换
const handleGeminiExchange = async (authCode: string) => {
if (!authCode.trim() || !geminiOAuth.sessionId.value) return
@@ -3218,12 +3077,6 @@ const handleAnthropicExchange = async (authCode: string) => {
extra.session_id_masking_enabled = true
}
// Add cache TTL override settings
if (cacheTTLOverrideEnabled.value) {
extra.cache_ttl_override_enabled = true
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
}
const credentials = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
@@ -3317,12 +3170,6 @@ const handleCookieAuth = async (sessionKey: string) => {
extra.session_id_masking_enabled = true
}
// Add cache TTL override settings
if (cacheTTLOverrideEnabled.value) {
extra.cache_ttl_override_enabled = true
extra.cache_ttl_override_target = cacheTTLOverrideTarget.value
}
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
// Merge interceptWarmupRequests into credentials

Some files were not shown because too many files have changed in this diff Show More