mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-05 07:52:13 +08:00
Compare commits
60 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a817cafe3d | ||
|
|
2857fa2ef7 | ||
|
|
e681431454 | ||
|
|
5b568aa9d4 | ||
|
|
471943269c | ||
|
|
28a5e2f0e6 | ||
|
|
b4c22ce6ce | ||
|
|
5248097f90 | ||
|
|
8e2c22d0bd | ||
|
|
be56a282f2 | ||
|
|
5f4eb9f9d0 | ||
|
|
d1cd5c0a73 | ||
|
|
5429c74c10 | ||
|
|
fe1d46a8ea | ||
|
|
c7b42148a5 | ||
|
|
bc1abb6a23 | ||
|
|
d307d48def | ||
|
|
1bb40084fc | ||
|
|
8f0efa16ca | ||
|
|
ef2c35dbb1 | ||
|
|
04a1a7c2b5 | ||
|
|
d21d70a5cf | ||
|
|
e73b778d2b | ||
|
|
723102766b | ||
|
|
a4a46a8618 | ||
|
|
6ae82e04d5 | ||
|
|
19cca11e00 | ||
|
|
c8f87a9c92 | ||
|
|
ae6fed15cc | ||
|
|
378e476e48 | ||
|
|
2a1067c82b | ||
|
|
a54b81cf74 | ||
|
|
2d4236f76e | ||
|
|
84ced1c497 | ||
|
|
b161312183 | ||
|
|
1f647b120a | ||
|
|
7d0a30fa8f | ||
|
|
d95e04fd1f | ||
|
|
5dd83d3cf2 | ||
|
|
14e1aac9b5 | ||
|
|
6114f69cca | ||
|
|
d6c2921f2b | ||
|
|
61c73287dc | ||
|
|
89905ec43d | ||
|
|
aa4b102108 | ||
|
|
e4bc35151f | ||
|
|
56da498b7e | ||
|
|
1bba1a62b1 | ||
|
|
4a84ca9a02 | ||
|
|
a70d37a676 | ||
|
|
6892e84ad2 | ||
|
|
73f455745c | ||
|
|
021abfca18 | ||
|
|
7d66f7ff0d | ||
|
|
470b37be7e | ||
|
|
f6cfab9901 | ||
|
|
51572b5da0 | ||
|
|
91ca28b7e3 | ||
|
|
04cedce9a1 | ||
|
|
5e0d789440 |
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
working-directory: backend
|
working-directory: backend
|
||||||
run: |
|
run: |
|
||||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||||
gosec -severity high -confidence high ./...
|
gosec -conf .gosec.json -severity high -confidence high ./...
|
||||||
|
|
||||||
frontend-security:
|
frontend-security:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
5
backend/.gosec.json
Normal file
5
backend/.gosec.json
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"global": {
|
||||||
|
"exclude": "G704"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1 +1 @@
|
|||||||
0.1.74.7
|
0.1.76
|
||||||
@@ -44,6 +44,8 @@ type ErrorPassthroughRule struct {
|
|||||||
PassthroughBody bool `json:"passthrough_body,omitempty"`
|
PassthroughBody bool `json:"passthrough_body,omitempty"`
|
||||||
// CustomMessage holds the value of the "custom_message" field.
|
// CustomMessage holds the value of the "custom_message" field.
|
||||||
CustomMessage *string `json:"custom_message,omitempty"`
|
CustomMessage *string `json:"custom_message,omitempty"`
|
||||||
|
// SkipMonitoring holds the value of the "skip_monitoring" field.
|
||||||
|
SkipMonitoring bool `json:"skip_monitoring,omitempty"`
|
||||||
// Description holds the value of the "description" field.
|
// Description holds the value of the "description" field.
|
||||||
Description *string `json:"description,omitempty"`
|
Description *string `json:"description,omitempty"`
|
||||||
selectValues sql.SelectValues
|
selectValues sql.SelectValues
|
||||||
@@ -56,7 +58,7 @@ func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) {
|
|||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms:
|
case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms:
|
||||||
values[i] = new([]byte)
|
values[i] = new([]byte)
|
||||||
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody:
|
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody, errorpassthroughrule.FieldSkipMonitoring:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode:
|
case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
@@ -171,6 +173,12 @@ func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) err
|
|||||||
_m.CustomMessage = new(string)
|
_m.CustomMessage = new(string)
|
||||||
*_m.CustomMessage = value.String
|
*_m.CustomMessage = value.String
|
||||||
}
|
}
|
||||||
|
case errorpassthroughrule.FieldSkipMonitoring:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field skip_monitoring", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.SkipMonitoring = value.Bool
|
||||||
|
}
|
||||||
case errorpassthroughrule.FieldDescription:
|
case errorpassthroughrule.FieldDescription:
|
||||||
if value, ok := values[i].(*sql.NullString); !ok {
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field description", values[i])
|
return fmt.Errorf("unexpected type %T for field description", values[i])
|
||||||
@@ -257,6 +265,9 @@ func (_m *ErrorPassthroughRule) String() string {
|
|||||||
builder.WriteString(*v)
|
builder.WriteString(*v)
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("skip_monitoring=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.SkipMonitoring))
|
||||||
|
builder.WriteString(", ")
|
||||||
if v := _m.Description; v != nil {
|
if v := _m.Description; v != nil {
|
||||||
builder.WriteString("description=")
|
builder.WriteString("description=")
|
||||||
builder.WriteString(*v)
|
builder.WriteString(*v)
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ const (
|
|||||||
FieldPassthroughBody = "passthrough_body"
|
FieldPassthroughBody = "passthrough_body"
|
||||||
// FieldCustomMessage holds the string denoting the custom_message field in the database.
|
// FieldCustomMessage holds the string denoting the custom_message field in the database.
|
||||||
FieldCustomMessage = "custom_message"
|
FieldCustomMessage = "custom_message"
|
||||||
|
// FieldSkipMonitoring holds the string denoting the skip_monitoring field in the database.
|
||||||
|
FieldSkipMonitoring = "skip_monitoring"
|
||||||
// FieldDescription holds the string denoting the description field in the database.
|
// FieldDescription holds the string denoting the description field in the database.
|
||||||
FieldDescription = "description"
|
FieldDescription = "description"
|
||||||
// Table holds the table name of the errorpassthroughrule in the database.
|
// Table holds the table name of the errorpassthroughrule in the database.
|
||||||
@@ -61,6 +63,7 @@ var Columns = []string{
|
|||||||
FieldResponseCode,
|
FieldResponseCode,
|
||||||
FieldPassthroughBody,
|
FieldPassthroughBody,
|
||||||
FieldCustomMessage,
|
FieldCustomMessage,
|
||||||
|
FieldSkipMonitoring,
|
||||||
FieldDescription,
|
FieldDescription,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,6 +98,8 @@ var (
|
|||||||
DefaultPassthroughCode bool
|
DefaultPassthroughCode bool
|
||||||
// DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field.
|
// DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field.
|
||||||
DefaultPassthroughBody bool
|
DefaultPassthroughBody bool
|
||||||
|
// DefaultSkipMonitoring holds the default value on creation for the "skip_monitoring" field.
|
||||||
|
DefaultSkipMonitoring bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the ErrorPassthroughRule queries.
|
// OrderOption defines the ordering options for the ErrorPassthroughRule queries.
|
||||||
@@ -155,6 +160,11 @@ func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldCustomMessage, opts...).ToFunc()
|
return sql.OrderByField(FieldCustomMessage, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BySkipMonitoring orders the results by the skip_monitoring field.
|
||||||
|
func BySkipMonitoring(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldSkipMonitoring, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByDescription orders the results by the description field.
|
// ByDescription orders the results by the description field.
|
||||||
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
|
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldDescription, opts...).ToFunc()
|
return sql.OrderByField(FieldDescription, opts...).ToFunc()
|
||||||
|
|||||||
@@ -104,6 +104,11 @@ func CustomMessage(v string) predicate.ErrorPassthroughRule {
|
|||||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
|
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SkipMonitoring applies equality check predicate on the "skip_monitoring" field. It's identical to SkipMonitoringEQ.
|
||||||
|
func SkipMonitoring(v bool) predicate.ErrorPassthroughRule {
|
||||||
|
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
|
||||||
|
}
|
||||||
|
|
||||||
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
|
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
|
||||||
func Description(v string) predicate.ErrorPassthroughRule {
|
func Description(v string) predicate.ErrorPassthroughRule {
|
||||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
|
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
|
||||||
@@ -544,6 +549,16 @@ func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule {
|
|||||||
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v))
|
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SkipMonitoringEQ applies the EQ predicate on the "skip_monitoring" field.
|
||||||
|
func SkipMonitoringEQ(v bool) predicate.ErrorPassthroughRule {
|
||||||
|
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipMonitoringNEQ applies the NEQ predicate on the "skip_monitoring" field.
|
||||||
|
func SkipMonitoringNEQ(v bool) predicate.ErrorPassthroughRule {
|
||||||
|
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldSkipMonitoring, v))
|
||||||
|
}
|
||||||
|
|
||||||
// DescriptionEQ applies the EQ predicate on the "description" field.
|
// DescriptionEQ applies the EQ predicate on the "description" field.
|
||||||
func DescriptionEQ(v string) predicate.ErrorPassthroughRule {
|
func DescriptionEQ(v string) predicate.ErrorPassthroughRule {
|
||||||
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
|
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
|
||||||
|
|||||||
@@ -172,6 +172,20 @@ func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *Error
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||||
|
func (_c *ErrorPassthroughRuleCreate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleCreate {
|
||||||
|
_c.mutation.SetSkipMonitoring(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
|
||||||
|
func (_c *ErrorPassthroughRuleCreate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetSkipMonitoring(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetDescription sets the "description" field.
|
// SetDescription sets the "description" field.
|
||||||
func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate {
|
func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate {
|
||||||
_c.mutation.SetDescription(v)
|
_c.mutation.SetDescription(v)
|
||||||
@@ -249,6 +263,10 @@ func (_c *ErrorPassthroughRuleCreate) defaults() {
|
|||||||
v := errorpassthroughrule.DefaultPassthroughBody
|
v := errorpassthroughrule.DefaultPassthroughBody
|
||||||
_c.mutation.SetPassthroughBody(v)
|
_c.mutation.SetPassthroughBody(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.SkipMonitoring(); !ok {
|
||||||
|
v := errorpassthroughrule.DefaultSkipMonitoring
|
||||||
|
_c.mutation.SetSkipMonitoring(v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check runs all checks and user-defined validators on the builder.
|
// check runs all checks and user-defined validators on the builder.
|
||||||
@@ -287,6 +305,9 @@ func (_c *ErrorPassthroughRuleCreate) check() error {
|
|||||||
if _, ok := _c.mutation.PassthroughBody(); !ok {
|
if _, ok := _c.mutation.PassthroughBody(); !ok {
|
||||||
return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)}
|
return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.SkipMonitoring(); !ok {
|
||||||
|
return &ValidationError{Name: "skip_monitoring", err: errors.New(`ent: missing required field "ErrorPassthroughRule.skip_monitoring"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -366,6 +387,10 @@ func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlg
|
|||||||
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
|
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
|
||||||
_node.CustomMessage = &value
|
_node.CustomMessage = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.SkipMonitoring(); ok {
|
||||||
|
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
|
||||||
|
_node.SkipMonitoring = value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.Description(); ok {
|
if value, ok := _c.mutation.Description(); ok {
|
||||||
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
||||||
_node.Description = &value
|
_node.Description = &value
|
||||||
@@ -608,6 +633,18 @@ func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleU
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||||
|
func (u *ErrorPassthroughRuleUpsert) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsert {
|
||||||
|
u.Set(errorpassthroughrule.FieldSkipMonitoring, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
|
||||||
|
func (u *ErrorPassthroughRuleUpsert) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsert {
|
||||||
|
u.SetExcluded(errorpassthroughrule.FieldSkipMonitoring)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetDescription sets the "description" field.
|
// SetDescription sets the "description" field.
|
||||||
func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert {
|
func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert {
|
||||||
u.Set(errorpassthroughrule.FieldDescription, v)
|
u.Set(errorpassthroughrule.FieldDescription, v)
|
||||||
@@ -888,6 +925,20 @@ func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRu
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||||
|
func (u *ErrorPassthroughRuleUpsertOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertOne {
|
||||||
|
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||||
|
s.SetSkipMonitoring(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
|
||||||
|
func (u *ErrorPassthroughRuleUpsertOne) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertOne {
|
||||||
|
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||||
|
s.UpdateSkipMonitoring()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetDescription sets the "description" field.
|
// SetDescription sets the "description" field.
|
||||||
func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne {
|
func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne {
|
||||||
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||||
@@ -1337,6 +1388,20 @@ func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughR
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||||
|
func (u *ErrorPassthroughRuleUpsertBulk) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertBulk {
|
||||||
|
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||||
|
s.SetSkipMonitoring(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
|
||||||
|
func (u *ErrorPassthroughRuleUpsertBulk) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertBulk {
|
||||||
|
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||||
|
s.UpdateSkipMonitoring()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetDescription sets the "description" field.
|
// SetDescription sets the "description" field.
|
||||||
func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk {
|
func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk {
|
||||||
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
|
||||||
|
|||||||
@@ -227,6 +227,20 @@ func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRule
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||||
|
func (_u *ErrorPassthroughRuleUpdate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdate {
|
||||||
|
_u.mutation.SetSkipMonitoring(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
|
||||||
|
func (_u *ErrorPassthroughRuleUpdate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSkipMonitoring(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetDescription sets the "description" field.
|
// SetDescription sets the "description" field.
|
||||||
func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate {
|
func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate {
|
||||||
_u.mutation.SetDescription(v)
|
_u.mutation.SetDescription(v)
|
||||||
@@ -387,6 +401,9 @@ func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, e
|
|||||||
if _u.mutation.CustomMessageCleared() {
|
if _u.mutation.CustomMessageCleared() {
|
||||||
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
|
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.SkipMonitoring(); ok {
|
||||||
|
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.Description(); ok {
|
if value, ok := _u.mutation.Description(); ok {
|
||||||
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
||||||
}
|
}
|
||||||
@@ -611,6 +628,20 @@ func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughR
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||||
|
func (_u *ErrorPassthroughRuleUpdateOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdateOne {
|
||||||
|
_u.mutation.SetSkipMonitoring(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
|
||||||
|
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSkipMonitoring(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetDescription sets the "description" field.
|
// SetDescription sets the "description" field.
|
||||||
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne {
|
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne {
|
||||||
_u.mutation.SetDescription(v)
|
_u.mutation.SetDescription(v)
|
||||||
@@ -801,6 +832,9 @@ func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *Er
|
|||||||
if _u.mutation.CustomMessageCleared() {
|
if _u.mutation.CustomMessageCleared() {
|
||||||
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
|
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.SkipMonitoring(); ok {
|
||||||
|
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.Description(); ok {
|
if value, ok := _u.mutation.Description(); ok {
|
||||||
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -325,6 +325,7 @@ var (
|
|||||||
{Name: "response_code", Type: field.TypeInt, Nullable: true},
|
{Name: "response_code", Type: field.TypeInt, Nullable: true},
|
||||||
{Name: "passthrough_body", Type: field.TypeBool, Default: true},
|
{Name: "passthrough_body", Type: field.TypeBool, Default: true},
|
||||||
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
||||||
|
{Name: "skip_monitoring", Type: field.TypeBool, Default: false},
|
||||||
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
|
||||||
}
|
}
|
||||||
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.
|
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.
|
||||||
|
|||||||
@@ -5776,6 +5776,7 @@ type ErrorPassthroughRuleMutation struct {
|
|||||||
addresponse_code *int
|
addresponse_code *int
|
||||||
passthrough_body *bool
|
passthrough_body *bool
|
||||||
custom_message *string
|
custom_message *string
|
||||||
|
skip_monitoring *bool
|
||||||
description *string
|
description *string
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
done bool
|
done bool
|
||||||
@@ -6503,6 +6504,42 @@ func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() {
|
|||||||
delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage)
|
delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSkipMonitoring sets the "skip_monitoring" field.
|
||||||
|
func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) {
|
||||||
|
m.skip_monitoring = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation.
|
||||||
|
func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) {
|
||||||
|
v := m.skip_monitoring
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity.
|
||||||
|
// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldSkipMonitoring requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.SkipMonitoring, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetSkipMonitoring resets all changes to the "skip_monitoring" field.
|
||||||
|
func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() {
|
||||||
|
m.skip_monitoring = nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetDescription sets the "description" field.
|
// SetDescription sets the "description" field.
|
||||||
func (m *ErrorPassthroughRuleMutation) SetDescription(s string) {
|
func (m *ErrorPassthroughRuleMutation) SetDescription(s string) {
|
||||||
m.description = &s
|
m.description = &s
|
||||||
@@ -6586,7 +6623,7 @@ func (m *ErrorPassthroughRuleMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *ErrorPassthroughRuleMutation) Fields() []string {
|
func (m *ErrorPassthroughRuleMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 14)
|
fields := make([]string, 0, 15)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, errorpassthroughrule.FieldCreatedAt)
|
fields = append(fields, errorpassthroughrule.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -6626,6 +6663,9 @@ func (m *ErrorPassthroughRuleMutation) Fields() []string {
|
|||||||
if m.custom_message != nil {
|
if m.custom_message != nil {
|
||||||
fields = append(fields, errorpassthroughrule.FieldCustomMessage)
|
fields = append(fields, errorpassthroughrule.FieldCustomMessage)
|
||||||
}
|
}
|
||||||
|
if m.skip_monitoring != nil {
|
||||||
|
fields = append(fields, errorpassthroughrule.FieldSkipMonitoring)
|
||||||
|
}
|
||||||
if m.description != nil {
|
if m.description != nil {
|
||||||
fields = append(fields, errorpassthroughrule.FieldDescription)
|
fields = append(fields, errorpassthroughrule.FieldDescription)
|
||||||
}
|
}
|
||||||
@@ -6663,6 +6703,8 @@ func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.PassthroughBody()
|
return m.PassthroughBody()
|
||||||
case errorpassthroughrule.FieldCustomMessage:
|
case errorpassthroughrule.FieldCustomMessage:
|
||||||
return m.CustomMessage()
|
return m.CustomMessage()
|
||||||
|
case errorpassthroughrule.FieldSkipMonitoring:
|
||||||
|
return m.SkipMonitoring()
|
||||||
case errorpassthroughrule.FieldDescription:
|
case errorpassthroughrule.FieldDescription:
|
||||||
return m.Description()
|
return m.Description()
|
||||||
}
|
}
|
||||||
@@ -6700,6 +6742,8 @@ func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string
|
|||||||
return m.OldPassthroughBody(ctx)
|
return m.OldPassthroughBody(ctx)
|
||||||
case errorpassthroughrule.FieldCustomMessage:
|
case errorpassthroughrule.FieldCustomMessage:
|
||||||
return m.OldCustomMessage(ctx)
|
return m.OldCustomMessage(ctx)
|
||||||
|
case errorpassthroughrule.FieldSkipMonitoring:
|
||||||
|
return m.OldSkipMonitoring(ctx)
|
||||||
case errorpassthroughrule.FieldDescription:
|
case errorpassthroughrule.FieldDescription:
|
||||||
return m.OldDescription(ctx)
|
return m.OldDescription(ctx)
|
||||||
}
|
}
|
||||||
@@ -6802,6 +6846,13 @@ func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) er
|
|||||||
}
|
}
|
||||||
m.SetCustomMessage(v)
|
m.SetCustomMessage(v)
|
||||||
return nil
|
return nil
|
||||||
|
case errorpassthroughrule.FieldSkipMonitoring:
|
||||||
|
v, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetSkipMonitoring(v)
|
||||||
|
return nil
|
||||||
case errorpassthroughrule.FieldDescription:
|
case errorpassthroughrule.FieldDescription:
|
||||||
v, ok := value.(string)
|
v, ok := value.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -6963,6 +7014,9 @@ func (m *ErrorPassthroughRuleMutation) ResetField(name string) error {
|
|||||||
case errorpassthroughrule.FieldCustomMessage:
|
case errorpassthroughrule.FieldCustomMessage:
|
||||||
m.ResetCustomMessage()
|
m.ResetCustomMessage()
|
||||||
return nil
|
return nil
|
||||||
|
case errorpassthroughrule.FieldSkipMonitoring:
|
||||||
|
m.ResetSkipMonitoring()
|
||||||
|
return nil
|
||||||
case errorpassthroughrule.FieldDescription:
|
case errorpassthroughrule.FieldDescription:
|
||||||
m.ResetDescription()
|
m.ResetDescription()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -326,6 +326,10 @@ func init() {
|
|||||||
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
|
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
|
||||||
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
|
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
|
||||||
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
|
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
|
||||||
|
// errorpassthroughruleDescSkipMonitoring is the schema descriptor for skip_monitoring field.
|
||||||
|
errorpassthroughruleDescSkipMonitoring := errorpassthroughruleFields[11].Descriptor()
|
||||||
|
// errorpassthroughrule.DefaultSkipMonitoring holds the default value on creation for the skip_monitoring field.
|
||||||
|
errorpassthroughrule.DefaultSkipMonitoring = errorpassthroughruleDescSkipMonitoring.Default.(bool)
|
||||||
groupMixin := schema.Group{}.Mixin()
|
groupMixin := schema.Group{}.Mixin()
|
||||||
groupMixinHooks1 := groupMixin[1].Hooks()
|
groupMixinHooks1 := groupMixin[1].Hooks()
|
||||||
group.Hooks[0] = groupMixinHooks1[0]
|
group.Hooks[0] = groupMixinHooks1[0]
|
||||||
|
|||||||
@@ -105,6 +105,12 @@ func (ErrorPassthroughRule) Fields() []ent.Field {
|
|||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
|
||||||
|
// skip_monitoring: 是否跳过运维监控记录
|
||||||
|
// true: 匹配此规则的错误不会被记录到 ops_error_logs
|
||||||
|
// false: 正常记录到运维监控(默认行为)
|
||||||
|
field.Bool("skip_monitoring").
|
||||||
|
Default(false),
|
||||||
|
|
||||||
// description: 规则描述,用于说明规则的用途
|
// description: 规则描述,用于说明规则的用途
|
||||||
field.Text("description").
|
field.Text("description").
|
||||||
Optional().
|
Optional().
|
||||||
|
|||||||
@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
|
|||||||
pageSize := dataPageCap
|
pageSize := dataPageCap
|
||||||
var out []service.Account
|
var out []service.Account
|
||||||
for {
|
for {
|
||||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
|
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -156,7 +156,12 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
search = search[:100]
|
search = search[:100]
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
var groupID int64
|
||||||
|
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||||
|
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -424,10 +429,17 @@ type TestAccountRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SyncFromCRSRequest struct {
|
type SyncFromCRSRequest struct {
|
||||||
BaseURL string `json:"base_url" binding:"required"`
|
BaseURL string `json:"base_url" binding:"required"`
|
||||||
Username string `json:"username" binding:"required"`
|
Username string `json:"username" binding:"required"`
|
||||||
Password string `json:"password" binding:"required"`
|
Password string `json:"password" binding:"required"`
|
||||||
SyncProxies *bool `json:"sync_proxies"`
|
SyncProxies *bool `json:"sync_proxies"`
|
||||||
|
SelectedAccountIDs []string `json:"selected_account_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreviewFromCRSRequest struct {
|
||||||
|
BaseURL string `json:"base_url" binding:"required"`
|
||||||
|
Username string `json:"username" binding:"required"`
|
||||||
|
Password string `json:"password" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test handles testing account connectivity with SSE streaming
|
// Test handles testing account connectivity with SSE streaming
|
||||||
@@ -466,10 +478,11 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
||||||
BaseURL: req.BaseURL,
|
BaseURL: req.BaseURL,
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
SyncProxies: syncProxies,
|
SyncProxies: syncProxies,
|
||||||
|
SelectedAccountIDs: req.SelectedAccountIDs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Provide detailed error message for CRS sync failures
|
// Provide detailed error message for CRS sync failures
|
||||||
@@ -480,6 +493,28 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
|
|||||||
response.Success(c, result)
|
response.Success(c, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PreviewFromCRS handles previewing accounts from CRS before sync
|
||||||
|
// POST /api/v1/admin/accounts/sync/crs/preview
|
||||||
|
func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
|
||||||
|
var req PreviewFromCRSRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.crsSyncService.PreviewFromCRS(c.Request.Context(), service.SyncFromCRSInput{
|
||||||
|
BaseURL: req.BaseURL,
|
||||||
|
Username: req.Username,
|
||||||
|
Password: req.Password,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "CRS preview failed: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
// Refresh handles refreshing account credentials
|
// Refresh handles refreshing account credentials
|
||||||
// POST /api/v1/admin/accounts/:id/refresh
|
// POST /api/v1/admin/accounts/:id/refresh
|
||||||
func (h *AccountHandler) Refresh(c *gin.Context) {
|
func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||||
@@ -1399,7 +1434,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
|||||||
accounts := make([]*service.Account, 0)
|
accounts := make([]*service.Account, 0)
|
||||||
|
|
||||||
if len(req.AccountIDs) == 0 {
|
if len(req.AccountIDs) == 0 {
|
||||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
|||||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
|
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||||
return s.accounts, int64(len(s.accounts)), nil
|
return s.accounts, int64(len(s.accounts)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -65,3 +65,27 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, tokenInfo)
|
response.Success(c, tokenInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AntigravityRefreshTokenRequest represents the request for validating Antigravity refresh token
|
||||||
|
type AntigravityRefreshTokenRequest struct {
|
||||||
|
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||||
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken validates an Antigravity refresh token and returns full token info
|
||||||
|
// POST /api/v1/admin/antigravity/oauth/refresh-token
|
||||||
|
func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) {
|
||||||
|
var req AntigravityRefreshTokenRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "请求无效: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, tokenInfo)
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type CreateErrorPassthroughRuleRequest struct {
|
|||||||
ResponseCode *int `json:"response_code"`
|
ResponseCode *int `json:"response_code"`
|
||||||
PassthroughBody *bool `json:"passthrough_body"`
|
PassthroughBody *bool `json:"passthrough_body"`
|
||||||
CustomMessage *string `json:"custom_message"`
|
CustomMessage *string `json:"custom_message"`
|
||||||
|
SkipMonitoring *bool `json:"skip_monitoring"`
|
||||||
Description *string `json:"description"`
|
Description *string `json:"description"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,6 +49,7 @@ type UpdateErrorPassthroughRuleRequest struct {
|
|||||||
ResponseCode *int `json:"response_code"`
|
ResponseCode *int `json:"response_code"`
|
||||||
PassthroughBody *bool `json:"passthrough_body"`
|
PassthroughBody *bool `json:"passthrough_body"`
|
||||||
CustomMessage *string `json:"custom_message"`
|
CustomMessage *string `json:"custom_message"`
|
||||||
|
SkipMonitoring *bool `json:"skip_monitoring"`
|
||||||
Description *string `json:"description"`
|
Description *string `json:"description"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,6 +124,9 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
rule.PassthroughBody = true
|
rule.PassthroughBody = true
|
||||||
}
|
}
|
||||||
|
if req.SkipMonitoring != nil {
|
||||||
|
rule.SkipMonitoring = *req.SkipMonitoring
|
||||||
|
}
|
||||||
rule.ResponseCode = req.ResponseCode
|
rule.ResponseCode = req.ResponseCode
|
||||||
rule.CustomMessage = req.CustomMessage
|
rule.CustomMessage = req.CustomMessage
|
||||||
rule.Description = req.Description
|
rule.Description = req.Description
|
||||||
@@ -190,6 +195,7 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
|
|||||||
ResponseCode: existing.ResponseCode,
|
ResponseCode: existing.ResponseCode,
|
||||||
PassthroughBody: existing.PassthroughBody,
|
PassthroughBody: existing.PassthroughBody,
|
||||||
CustomMessage: existing.CustomMessage,
|
CustomMessage: existing.CustomMessage,
|
||||||
|
SkipMonitoring: existing.SkipMonitoring,
|
||||||
Description: existing.Description,
|
Description: existing.Description,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,6 +236,9 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
|
|||||||
if req.Description != nil {
|
if req.Description != nil {
|
||||||
rule.Description = req.Description
|
rule.Description = req.Description
|
||||||
}
|
}
|
||||||
|
if req.SkipMonitoring != nil {
|
||||||
|
rule.SkipMonitoring = *req.SkipMonitoring
|
||||||
|
}
|
||||||
|
|
||||||
// 确保切片不为 nil
|
// 确保切片不为 nil
|
||||||
if rule.ErrorCodes == nil {
|
if rule.ErrorCodes == nil {
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
writer := csv.NewWriter(&buf)
|
writer := csv.NewWriter(&buf)
|
||||||
|
|
||||||
// Write header
|
// Write header
|
||||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
|
||||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -213,6 +213,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
if code.UsedBy != nil {
|
if code.UsedBy != nil {
|
||||||
usedBy = fmt.Sprintf("%d", *code.UsedBy)
|
usedBy = fmt.Sprintf("%d", *code.UsedBy)
|
||||||
}
|
}
|
||||||
|
usedByEmail := ""
|
||||||
|
if code.User != nil {
|
||||||
|
usedByEmail = code.User.Email
|
||||||
|
}
|
||||||
usedAt := ""
|
usedAt := ""
|
||||||
if code.UsedAt != nil {
|
if code.UsedAt != nil {
|
||||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||||
@@ -224,6 +228,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
fmt.Sprintf("%.2f", code.Value),
|
fmt.Sprintf("%.2f", code.Value),
|
||||||
code.Status,
|
code.Status,
|
||||||
usedBy,
|
usedBy,
|
||||||
|
usedByEmail,
|
||||||
usedAt,
|
usedAt,
|
||||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
|||||||
@@ -235,9 +235,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
|
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||||
|
|
||||||
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
|
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -245,6 +253,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||||
|
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||||
|
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||||
|
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||||
|
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||||
|
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||||
|
failedAccountIDs = make(map[int64]struct{})
|
||||||
|
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
if lastFailoverErr != nil {
|
if lastFailoverErr != nil {
|
||||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
||||||
} else {
|
} else {
|
||||||
@@ -339,11 +360,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
lastFailoverErr = failoverErr
|
lastFailoverErr = failoverErr
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||||
forceCacheBilling = true
|
forceCacheBilling = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||||
|
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||||
|
sameAccountRetryCount[account.ID]++
|
||||||
|
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||||
|
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||||
|
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同账号重试用尽,执行临时封禁并切换账号
|
||||||
|
if failoverErr.RetryableOnSameAccount {
|
||||||
|
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||||
return
|
return
|
||||||
@@ -396,10 +434,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
fallbackUsed := false
|
fallbackUsed := false
|
||||||
|
|
||||||
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
|
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
|
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
retryWithFallback := false
|
retryWithFallback := false
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||||
@@ -412,6 +458,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||||
|
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||||
|
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||||
|
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||||
|
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||||
|
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||||
|
failedAccountIDs = make(map[int64]struct{})
|
||||||
|
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
if lastFailoverErr != nil {
|
if lastFailoverErr != nil {
|
||||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
||||||
} else {
|
} else {
|
||||||
@@ -539,11 +598,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
lastFailoverErr = failoverErr
|
lastFailoverErr = failoverErr
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||||
forceCacheBilling = true
|
forceCacheBilling = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||||
|
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||||
|
sameAccountRetryCount[account.ID]++
|
||||||
|
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||||
|
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||||
|
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同账号重试用尽,执行临时封禁并切换账号
|
||||||
|
if failoverErr.RetryableOnSameAccount {
|
||||||
|
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||||
return
|
return
|
||||||
@@ -823,6 +899,23 @@ func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFa
|
|||||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||||
|
maxSameAccountRetries = 2
|
||||||
|
// sameAccountRetryDelay 同账号重试间隔
|
||||||
|
sameAccountRetryDelay = 500 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
|
||||||
|
func sleepSameAccountRetryDelay(ctx context.Context) bool {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-time.After(sameAccountRetryDelay):
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
||||||
// 返回 false 表示 context 已取消。
|
// 返回 false 表示 context 已取消。
|
||||||
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
||||||
@@ -838,6 +931,27 @@ func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
|
||||||
|
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
|
||||||
|
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
|
||||||
|
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
|
||||||
|
// 返回 false 表示 context 已取消。
|
||||||
|
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
|
||||||
|
// 固定短延时:2s
|
||||||
|
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
|
||||||
|
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||||
|
const delay = 2 * time.Second
|
||||||
|
|
||||||
|
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-time.After(delay):
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||||
statusCode := failoverErr.StatusCode
|
statusCode := failoverErr.StatusCode
|
||||||
responseBody := failoverErr.ResponseBody
|
responseBody := failoverErr.ResponseBody
|
||||||
@@ -857,6 +971,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
|
|||||||
msg = *rule.CustomMessage
|
msg = *rule.CustomMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.SkipMonitoring {
|
||||||
|
c.Set(service.OpsSkipPassthroughKey, true)
|
||||||
|
}
|
||||||
|
|
||||||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// sleepAntigravitySingleAccountBackoff 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.True(t, ok, "should return true when context is not canceled")
|
||||||
|
// 固定延迟 2s
|
||||||
|
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
|
||||||
|
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // 立即取消
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.False(t, ok, "should return false when context is canceled")
|
||||||
|
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
|
||||||
|
// 验证不同 retryCount 都使用固定 2s 延迟
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.True(t, ok)
|
||||||
|
// 即使 retryCount=5,延迟仍然是固定的 2s
|
||||||
|
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
|
||||||
|
require.Less(t, elapsed, 5*time.Second)
|
||||||
|
}
|
||||||
@@ -327,6 +327,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
var lastFailoverErr *service.UpstreamFailoverError
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||||
|
|
||||||
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
|
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -334,6 +341,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||||
|
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||||
|
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||||
|
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||||
|
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||||
|
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||||
|
failedAccountIDs = make(map[int64]struct{})
|
||||||
|
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -534,6 +554,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
|
|||||||
msg = *rule.CustomMessage
|
msg = *rule.CustomMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.SkipMonitoring {
|
||||||
|
c.Set(service.OpsSkipPassthroughKey, true)
|
||||||
|
}
|
||||||
|
|
||||||
googleError(c, respCode, msg)
|
googleError(c, respCode, msg)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -354,6 +354,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
|
|||||||
msg = *rule.CustomMessage
|
msg = *rule.CustomMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.SkipMonitoring {
|
||||||
|
c.Set(service.OpsSkipPassthroughKey, true)
|
||||||
|
}
|
||||||
|
|
||||||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -537,6 +537,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
||||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||||
|
|
||||||
|
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||||
|
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||||
|
if skip, _ := v.(bool); skip {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -544,6 +551,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
body := w.buf.Bytes()
|
body := w.buf.Bytes()
|
||||||
parsed := parseOpsErrorResponse(body)
|
parsed := parseOpsErrorResponse(body)
|
||||||
|
|
||||||
|
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||||
|
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||||
|
if skip, _ := v.(bool); skip {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Skip logging if the error should be filtered based on settings
|
// Skip logging if the error should be filtered based on settings
|
||||||
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
|
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type ErrorPassthroughRule struct {
|
|||||||
ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用)
|
ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用)
|
||||||
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
|
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
|
||||||
CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用)
|
CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用)
|
||||||
|
SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录
|
||||||
Description *string `json:"description"` // 规则描述
|
Description *string `json:"description"` // 规则描述
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ type ClaudeMessage struct {
|
|||||||
|
|
||||||
// ThinkingConfig Thinking 配置
|
// ThinkingConfig Thinking 配置
|
||||||
type ThinkingConfig struct {
|
type ThinkingConfig struct {
|
||||||
Type string `json:"type"` // "enabled" or "disabled"
|
Type string `json:"type"` // "enabled" / "adaptive" / "disabled"
|
||||||
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
|
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -115,6 +115,23 @@ type LoadCodeAssistResponse struct {
|
|||||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnboardUserRequest onboardUser 请求
|
||||||
|
type OnboardUserRequest struct {
|
||||||
|
TierID string `json:"tierId"`
|
||||||
|
Metadata struct {
|
||||||
|
IDEType string `json:"ideType"`
|
||||||
|
Platform string `json:"platform,omitempty"`
|
||||||
|
PluginType string `json:"pluginType,omitempty"`
|
||||||
|
} `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnboardUserResponse onboardUser 响应
|
||||||
|
type OnboardUserResponse struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Done bool `json:"done"`
|
||||||
|
Response map[string]any `json:"response,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// GetTier 获取账户类型
|
// GetTier 获取账户类型
|
||||||
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
|
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
|
||||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||||
@@ -361,6 +378,117 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
|||||||
return nil, nil, lastErr
|
return nil, nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnboardUser 触发账号 onboarding,并返回 project_id
|
||||||
|
// 说明:
|
||||||
|
// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject;
|
||||||
|
// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。
|
||||||
|
func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||||
|
tierID = strings.TrimSpace(tierID)
|
||||||
|
if tierID == "" {
|
||||||
|
return "", fmt.Errorf("tier_id 为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := OnboardUserRequest{TierID: tierID}
|
||||||
|
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||||
|
reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED"
|
||||||
|
reqBody.Metadata.PluginType = "GEMINI"
|
||||||
|
|
||||||
|
bodyBytes, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
availableURLs := BaseURLs
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for urlIdx, baseURL := range availableURLs {
|
||||||
|
apiURL := baseURL + "/v1internal:onboardUser"
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= 5; attempt++ {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes))
|
||||||
|
if err != nil {
|
||||||
|
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = fmt.Errorf("onboardUser 请求失败: %w", err)
|
||||||
|
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||||
|
log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||||
|
log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
var onboardResp OnboardUserResponse
|
||||||
|
if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil {
|
||||||
|
lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err)
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if onboardResp.Done {
|
||||||
|
if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" {
|
||||||
|
DefaultURLAvailability.MarkSuccess(baseURL)
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id")
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// done=false 时等待后重试(与 CLIProxyAPI 行为一致)
|
||||||
|
select {
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("onboardUser 未返回 project_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractProjectIDFromOnboardResponse(resp map[string]any) string {
|
||||||
|
if len(resp) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := resp["cloudaicompanionProject"]; ok {
|
||||||
|
switch project := v.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(project)
|
||||||
|
case map[string]any:
|
||||||
|
if id, ok := project["id"].(string); ok {
|
||||||
|
return strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// ModelQuotaInfo 模型配额信息
|
// ModelQuotaInfo 模型配额信息
|
||||||
type ModelQuotaInfo struct {
|
type ModelQuotaInfo struct {
|
||||||
RemainingFraction float64 `json:"remainingFraction"`
|
RemainingFraction float64 `json:"remainingFraction"`
|
||||||
|
|||||||
76
backend/internal/pkg/antigravity/client_test.go
Normal file
76
backend/internal/pkg/antigravity/client_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractProjectIDFromOnboardResponse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
resp map[string]any
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil response",
|
||||||
|
resp: nil,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty response",
|
||||||
|
resp: map[string]any{},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "project as string",
|
||||||
|
resp: map[string]any{
|
||||||
|
"cloudaicompanionProject": "my-project-123",
|
||||||
|
},
|
||||||
|
want: "my-project-123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "project as string with spaces",
|
||||||
|
resp: map[string]any{
|
||||||
|
"cloudaicompanionProject": " my-project-123 ",
|
||||||
|
},
|
||||||
|
want: "my-project-123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "project as map with id",
|
||||||
|
resp: map[string]any{
|
||||||
|
"cloudaicompanionProject": map[string]any{
|
||||||
|
"id": "proj-from-map",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "proj-from-map",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "project as map without id",
|
||||||
|
resp: map[string]any{
|
||||||
|
"cloudaicompanionProject": map[string]any{
|
||||||
|
"name": "some-name",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing cloudaicompanionProject key",
|
||||||
|
resp: map[string]any{
|
||||||
|
"otherField": "value",
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := extractProjectIDFromOnboardResponse(tc.resp)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Fatalf("extractProjectIDFromOnboardResponse() = %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -155,6 +155,7 @@ type GeminiUsageMetadata struct {
|
|||||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||||
|
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
||||||
|
|||||||
@@ -64,6 +64,10 @@ const MaxTokensBudgetPadding = 1000
|
|||||||
// Gemini 2.5 Flash thinking budget 上限
|
// Gemini 2.5 Flash thinking budget 上限
|
||||||
const Gemini25FlashThinkingBudgetLimit = 24576
|
const Gemini25FlashThinkingBudgetLimit = 24576
|
||||||
|
|
||||||
|
// 对于 Antigravity 的 Claude(budget-only)模型,该语义最终等价为 thinkingBudget=24576。
|
||||||
|
// 这里复用相同数值以保持行为一致。
|
||||||
|
const ClaudeAdaptiveHighThinkingBudgetTokens = Gemini25FlashThinkingBudgetLimit
|
||||||
|
|
||||||
// ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens
|
// ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens
|
||||||
// Claude API 要求启用 thinking 时,max_tokens 必须大于 thinking.budget_tokens
|
// Claude API 要求启用 thinking 时,max_tokens 必须大于 thinking.budget_tokens
|
||||||
// 返回调整后的 maxTokens 和是否进行了调整
|
// 返回调整后的 maxTokens 和是否进行了调整
|
||||||
@@ -96,7 +100,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检测是否启用 thinking
|
// 检测是否启用 thinking
|
||||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||||
|
|
||||||
// 只有 Gemini 模型支持 dummy thought workaround
|
// 只有 Gemini 模型支持 dummy thought workaround
|
||||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||||
@@ -198,8 +202,7 @@ type modelInfo struct {
|
|||||||
|
|
||||||
// modelInfoMap 模型前缀 → 模型信息映射
|
// modelInfoMap 模型前缀 → 模型信息映射
|
||||||
// 只有在此映射表中的模型才会注入身份提示词
|
// 只有在此映射表中的模型才会注入身份提示词
|
||||||
// 注意:当前 claude-opus-4-6 会被映射到 claude-opus-4-5-thinking,
|
// 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。
|
||||||
// 但保留此条目以便后续 Antigravity 上游支持 4.6 时快速切换
|
|
||||||
var modelInfoMap = map[string]modelInfo{
|
var modelInfoMap = map[string]modelInfo{
|
||||||
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
||||||
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
||||||
@@ -271,6 +274,21 @@ func filterOpenCodePrompt(text string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
||||||
|
var systemBlockFilterPrefixes = []string{
|
||||||
|
"x-anthropic-billing-header",
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
|
||||||
|
func filterSystemBlockByPrefix(text string) string {
|
||||||
|
for _, prefix := range systemBlockFilterPrefixes {
|
||||||
|
if strings.HasPrefix(text, prefix) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
|
||||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
@@ -287,8 +305,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(sysStr, "You are Antigravity") {
|
if strings.Contains(sysStr, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词
|
// 过滤 OpenCode 默认提示词和黑名单前缀
|
||||||
filtered := filterOpenCodePrompt(sysStr)
|
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
@@ -302,8 +320,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
|
|||||||
if strings.Contains(block.Text, "You are Antigravity") {
|
if strings.Contains(block.Text, "You are Antigravity") {
|
||||||
userHasAntigravityIdentity = true
|
userHasAntigravityIdentity = true
|
||||||
}
|
}
|
||||||
// 过滤 OpenCode 默认提示词
|
// 过滤 OpenCode 默认提示词和黑名单前缀
|
||||||
filtered := filterOpenCodePrompt(block.Text)
|
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
|
||||||
if filtered != "" {
|
if filtered != "" {
|
||||||
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
|
||||||
}
|
}
|
||||||
@@ -578,6 +596,10 @@ func maxOutputTokensLimit(model string) int {
|
|||||||
return maxOutputTokensUpperBound
|
return maxOutputTokensUpperBound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isAntigravityOpus46Model(model string) bool {
|
||||||
|
return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6")
|
||||||
|
}
|
||||||
|
|
||||||
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||||
maxLimit := maxOutputTokensLimit(req.Model)
|
maxLimit := maxOutputTokensLimit(req.Model)
|
||||||
config := &GeminiGenerationConfig{
|
config := &GeminiGenerationConfig{
|
||||||
@@ -591,25 +613,36 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Thinking 配置
|
// Thinking 配置
|
||||||
if req.Thinking != nil && req.Thinking.Type == "enabled" {
|
if req.Thinking != nil && (req.Thinking.Type == "enabled" || req.Thinking.Type == "adaptive") {
|
||||||
config.ThinkingConfig = &GeminiThinkingConfig{
|
config.ThinkingConfig = &GeminiThinkingConfig{
|
||||||
IncludeThoughts: true,
|
IncludeThoughts: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// - thinking.type=enabled:budget_tokens>0 用显式预算
|
||||||
|
// - thinking.type=adaptive:仅在 Antigravity 的 Opus 4.6 上覆写为 (24576)
|
||||||
|
budget := -1
|
||||||
if req.Thinking.BudgetTokens > 0 {
|
if req.Thinking.BudgetTokens > 0 {
|
||||||
budget := req.Thinking.BudgetTokens
|
budget = req.Thinking.BudgetTokens
|
||||||
|
}
|
||||||
|
if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) {
|
||||||
|
budget = ClaudeAdaptiveHighThinkingBudgetTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// 正预算需要做上限与 max_tokens 约束;动态预算(-1)直接透传给上游。
|
||||||
|
if budget > 0 {
|
||||||
// gemini-2.5-flash 上限
|
// gemini-2.5-flash 上限
|
||||||
if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit {
|
if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit {
|
||||||
budget = Gemini25FlashThinkingBudgetLimit
|
budget = Gemini25FlashThinkingBudgetLimit
|
||||||
}
|
}
|
||||||
config.ThinkingConfig.ThinkingBudget = budget
|
|
||||||
|
|
||||||
// 自动修正:max_tokens 必须大于 budget_tokens
|
// 自动修正:max_tokens 必须大于 budget_tokens(Claude 上游要求)
|
||||||
if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok {
|
if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok {
|
||||||
log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)",
|
log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)",
|
||||||
config.MaxOutputTokens, adjusted, budget)
|
config.MaxOutputTokens, adjusted, budget)
|
||||||
config.MaxOutputTokens = adjusted
|
config.MaxOutputTokens = adjusted
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
config.ThinkingConfig.ThinkingBudget = budget
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.MaxOutputTokens > maxLimit {
|
if config.MaxOutputTokens > maxLimit {
|
||||||
|
|||||||
@@ -259,3 +259,93 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
thinking *ThinkingConfig
|
||||||
|
wantBudget int
|
||||||
|
wantPresent bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enabled without budget defaults to dynamic (-1)",
|
||||||
|
model: "claude-opus-4-6-thinking",
|
||||||
|
thinking: &ThinkingConfig{Type: "enabled"},
|
||||||
|
wantBudget: -1,
|
||||||
|
wantPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled with budget uses the provided value",
|
||||||
|
model: "claude-opus-4-6-thinking",
|
||||||
|
thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1024},
|
||||||
|
wantBudget: 1024,
|
||||||
|
wantPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled with -1 budget uses dynamic (-1)",
|
||||||
|
model: "claude-opus-4-6-thinking",
|
||||||
|
thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: -1},
|
||||||
|
wantBudget: -1,
|
||||||
|
wantPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "adaptive on opus4.6 maps to high budget (24576)",
|
||||||
|
model: "claude-opus-4-6-thinking",
|
||||||
|
thinking: &ThinkingConfig{Type: "adaptive", BudgetTokens: 20000},
|
||||||
|
wantBudget: ClaudeAdaptiveHighThinkingBudgetTokens,
|
||||||
|
wantPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "adaptive on non-opus model keeps default dynamic (-1)",
|
||||||
|
model: "claude-sonnet-4-5-thinking",
|
||||||
|
thinking: &ThinkingConfig{Type: "adaptive"},
|
||||||
|
wantBudget: -1,
|
||||||
|
wantPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled does not emit thinkingConfig",
|
||||||
|
model: "claude-opus-4-6-thinking",
|
||||||
|
thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024},
|
||||||
|
wantBudget: 0,
|
||||||
|
wantPresent: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil thinking does not emit thinkingConfig",
|
||||||
|
model: "claude-opus-4-6-thinking",
|
||||||
|
thinking: nil,
|
||||||
|
wantBudget: 0,
|
||||||
|
wantPresent: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := &ClaudeRequest{
|
||||||
|
Model: tt.model,
|
||||||
|
Thinking: tt.thinking,
|
||||||
|
}
|
||||||
|
cfg := buildGenerationConfig(req)
|
||||||
|
if cfg == nil {
|
||||||
|
t.Fatalf("expected non-nil generationConfig")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantPresent {
|
||||||
|
if cfg.ThinkingConfig == nil {
|
||||||
|
t.Fatalf("expected thinkingConfig to be present")
|
||||||
|
}
|
||||||
|
if !cfg.ThinkingConfig.IncludeThoughts {
|
||||||
|
t.Fatalf("expected includeThoughts=true")
|
||||||
|
}
|
||||||
|
if cfg.ThinkingConfig.ThinkingBudget != tt.wantBudget {
|
||||||
|
t.Fatalf("expected thinkingBudget=%d, got %d", tt.wantBudget, cfg.ThinkingConfig.ThinkingBudget)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.ThinkingConfig != nil {
|
||||||
|
t.Fatalf("expected thinkingConfig to be nil, got %+v", cfg.ThinkingConfig)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
|||||||
if geminiResp.UsageMetadata != nil {
|
if geminiResp.UsageMetadata != nil {
|
||||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||||
usage.CacheReadInputTokens = cached
|
usage.CacheReadInputTokens = cached
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
|||||||
if geminiResp.UsageMetadata != nil {
|
if geminiResp.UsageMetadata != nil {
|
||||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||||
p.cacheReadTokens = cached
|
p.cacheReadTokens = cached
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
if v1Resp.Response.UsageMetadata != nil {
|
if v1Resp.Response.UsageMetadata != nil {
|
||||||
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
||||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
|
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
|
||||||
usage.CacheReadInputTokens = cached
|
usage.CacheReadInputTokens = cached
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,4 +28,8 @@ const (
|
|||||||
// IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求
|
// IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求
|
||||||
// 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent)
|
// 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent)
|
||||||
IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku"
|
IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku"
|
||||||
|
|
||||||
|
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
|
||||||
|
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
|
||||||
|
SingleAccountRetry Key = "ctx_single_account_retry"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -282,6 +282,34 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
|
|||||||
return &accounts[0], nil
|
return &accounts[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||||
|
rows, err := r.sql.QueryContext(ctx, `
|
||||||
|
SELECT id, extra->>'crs_account_id'
|
||||||
|
FROM accounts
|
||||||
|
WHERE deleted_at IS NULL
|
||||||
|
AND extra->>'crs_account_id' IS NOT NULL
|
||||||
|
AND extra->>'crs_account_id' != ''
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
result := make(map[string]int64)
|
||||||
|
for rows.Next() {
|
||||||
|
var id int64
|
||||||
|
var crsID string
|
||||||
|
if err := rows.Scan(&id, &crsID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[crsID] = id
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
|
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -407,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
return r.ListWithFilters(ctx, params, "", "", "", "", 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
q := r.client.Account.Query()
|
q := r.client.Account.Query()
|
||||||
|
|
||||||
if platform != "" {
|
if platform != "" {
|
||||||
@@ -420,11 +448,19 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
|||||||
q = q.Where(dbaccount.TypeEQ(accountType))
|
q = q.Where(dbaccount.TypeEQ(accountType))
|
||||||
}
|
}
|
||||||
if status != "" {
|
if status != "" {
|
||||||
q = q.Where(dbaccount.StatusEQ(status))
|
switch status {
|
||||||
|
case "rate_limited":
|
||||||
|
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
||||||
|
default:
|
||||||
|
q = q.Where(dbaccount.StatusEQ(status))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if search != "" {
|
if search != "" {
|
||||||
q = q.Where(dbaccount.NameContainsFold(search))
|
q = q.Where(dbaccount.NameContainsFold(search))
|
||||||
}
|
}
|
||||||
|
if groupID > 0 {
|
||||||
|
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
|
||||||
|
}
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
|
|
||||||
tt.setup(client)
|
tt.setup(client)
|
||||||
|
|
||||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(accounts, tt.wantCount)
|
s.Require().Len(accounts, tt.wantCount)
|
||||||
if tt.validate != nil {
|
if tt.validate != nil {
|
||||||
@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
|||||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||||
|
|
||||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
|
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
|
||||||
s.Require().NoError(err, "ListWithFilters")
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
s.Require().Equal(int64(1), page.Total)
|
s.Require().Equal(int64(1), page.Total)
|
||||||
s.Require().Len(accounts, 1)
|
s.Require().Len(accounts, 1)
|
||||||
|
|||||||
@@ -54,7 +54,8 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err
|
|||||||
SetPriority(rule.Priority).
|
SetPriority(rule.Priority).
|
||||||
SetMatchMode(rule.MatchMode).
|
SetMatchMode(rule.MatchMode).
|
||||||
SetPassthroughCode(rule.PassthroughCode).
|
SetPassthroughCode(rule.PassthroughCode).
|
||||||
SetPassthroughBody(rule.PassthroughBody)
|
SetPassthroughBody(rule.PassthroughBody).
|
||||||
|
SetSkipMonitoring(rule.SkipMonitoring)
|
||||||
|
|
||||||
if len(rule.ErrorCodes) > 0 {
|
if len(rule.ErrorCodes) > 0 {
|
||||||
builder.SetErrorCodes(rule.ErrorCodes)
|
builder.SetErrorCodes(rule.ErrorCodes)
|
||||||
@@ -90,7 +91,8 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err
|
|||||||
SetPriority(rule.Priority).
|
SetPriority(rule.Priority).
|
||||||
SetMatchMode(rule.MatchMode).
|
SetMatchMode(rule.MatchMode).
|
||||||
SetPassthroughCode(rule.PassthroughCode).
|
SetPassthroughCode(rule.PassthroughCode).
|
||||||
SetPassthroughBody(rule.PassthroughBody)
|
SetPassthroughBody(rule.PassthroughBody).
|
||||||
|
SetSkipMonitoring(rule.SkipMonitoring)
|
||||||
|
|
||||||
// 处理可选字段
|
// 处理可选字段
|
||||||
if len(rule.ErrorCodes) > 0 {
|
if len(rule.ErrorCodes) > 0 {
|
||||||
@@ -149,6 +151,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model
|
|||||||
Platforms: e.Platforms,
|
Platforms: e.Platforms,
|
||||||
PassthroughCode: e.PassthroughCode,
|
PassthroughCode: e.PassthroughCode,
|
||||||
PassthroughBody: e.PassthroughBody,
|
PassthroughBody: e.PassthroughBody,
|
||||||
|
SkipMonitoring: e.SkipMonitoring,
|
||||||
CreatedAt: e.CreatedAt,
|
CreatedAt: e.CreatedAt,
|
||||||
UpdatedAt: e.UpdatedAt,
|
UpdatedAt: e.UpdatedAt,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
)
|
)
|
||||||
@@ -106,7 +107,12 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
|
|||||||
q = q.Where(redeemcode.StatusEQ(status))
|
q = q.Where(redeemcode.StatusEQ(status))
|
||||||
}
|
}
|
||||||
if search != "" {
|
if search != "" {
|
||||||
q = q.Where(redeemcode.CodeContainsFold(search))
|
q = q.Where(
|
||||||
|
redeemcode.Or(
|
||||||
|
redeemcode.CodeContainsFold(search),
|
||||||
|
redeemcode.HasUserWith(user.EmailContainsFold(search)),
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Count(ctx)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||||
@@ -191,6 +192,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
|||||||
dbuser.EmailContainsFold(filters.Search),
|
dbuser.EmailContainsFold(filters.Search),
|
||||||
dbuser.UsernameContainsFold(filters.Search),
|
dbuser.UsernameContainsFold(filters.Search),
|
||||||
dbuser.NotesContainsFold(filters.Search),
|
dbuser.NotesContainsFold(filters.Search),
|
||||||
|
dbuser.HasAPIKeysWith(apikey.KeyContainsFold(filters.Search)),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -936,7 +936,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1049,6 +1049,10 @@ func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
|
|||||||
return int64(len(ids)), nil
|
return int64(len(ids)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubProxyRepo struct{}
|
type stubProxyRepo struct{}
|
||||||
|
|
||||||
func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error {
|
func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error {
|
||||||
|
|||||||
@@ -209,6 +209,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||||
accounts.POST("", h.Admin.Account.Create)
|
accounts.POST("", h.Admin.Account.Create)
|
||||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||||
|
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
||||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||||
@@ -280,6 +281,7 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers)
|
|||||||
{
|
{
|
||||||
antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL)
|
antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL)
|
||||||
antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode)
|
antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode)
|
||||||
|
antigravity.POST("/oauth/refresh-token", h.Admin.AntigravityOAuth.RefreshToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,11 +25,14 @@ type AccountRepository interface {
|
|||||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||||
// Returns (nil, nil) if not found.
|
// Returns (nil, nil) if not found.
|
||||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||||
|
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
|
||||||
|
// for all accounts that have been synced from CRS.
|
||||||
|
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
|
||||||
Update(ctx context.Context, account *Account) error
|
Update(ctx context.Context, account *Account) error
|
||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
|
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
|
||||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||||
ListActive(ctx context.Context) ([]Account, error)
|
ListActive(ctx context.Context) ([]Account, error)
|
||||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||||
|
|||||||
@@ -54,6 +54,10 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st
|
|||||||
panic("unexpected GetByCRSAccountID call")
|
panic("unexpected GetByCRSAccountID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||||
|
panic("unexpected ListCRSAccountIDs call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *accountRepoStub) Update(ctx context.Context, account *Account) error {
|
func (s *accountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||||
panic("unexpected Update call")
|
panic("unexpected Update call")
|
||||||
}
|
}
|
||||||
@@ -71,7 +75,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
|
|||||||
panic("unexpected List call")
|
panic("unexpected List call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||||
panic("unexpected ListWithFilters call")
|
panic("unexpected ListWithFilters call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ type AdminService interface {
|
|||||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||||
|
|
||||||
// Account management
|
// Account management
|
||||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
|
||||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||||
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||||
@@ -1021,9 +1021,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Account management implementations
|
// Account management implementations
|
||||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
|
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
|
|||||||
listWithFiltersErr error
|
listWithFiltersErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||||
s.listWithFiltersCalls++
|
s.listWithFiltersCalls++
|
||||||
s.listWithFiltersParams = params
|
s.listWithFiltersParams = params
|
||||||
s.listWithFiltersPlatform = platform
|
s.listWithFiltersPlatform = platform
|
||||||
@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svc := &adminServiceImpl{accountRepo: repo}
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
|
||||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, int64(10), total)
|
require.Equal(t, int64(10), total)
|
||||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||||
|
|||||||
@@ -16,10 +16,12 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@@ -39,6 +41,12 @@ const (
|
|||||||
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
|
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
|
||||||
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
|
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
|
||||||
|
|
||||||
|
// MODEL_CAPACITY_EXHAUSTED 专用重试参数
|
||||||
|
// 模型容量不足时,所有账号共享同一容量池,切换账号无意义
|
||||||
|
// 使用固定 1s 间隔重试,最多重试 60 次
|
||||||
|
antigravityModelCapacityRetryMaxAttempts = 60
|
||||||
|
antigravityModelCapacityRetryWait = 1 * time.Second
|
||||||
|
|
||||||
// Google RPC 状态和类型常量
|
// Google RPC 状态和类型常量
|
||||||
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
|
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
|
||||||
googleRPCStatusUnavailable = "UNAVAILABLE"
|
googleRPCStatusUnavailable = "UNAVAILABLE"
|
||||||
@@ -46,6 +54,22 @@ const (
|
|||||||
googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo"
|
googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo"
|
||||||
googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED"
|
googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED"
|
||||||
googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED"
|
googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED"
|
||||||
|
|
||||||
|
// 单账号 503 退避重试:Service 层原地重试的最大次数
|
||||||
|
// 在 handleSmartRetry 中,对于 shouldRateLimitModel(长延迟 ≥ 7s)的情况,
|
||||||
|
// 多账号模式下会设限流+切换账号;但单账号模式下改为原地等待+重试。
|
||||||
|
antigravitySingleAccountSmartRetryMaxAttempts = 3
|
||||||
|
|
||||||
|
// 单账号 503 退避重试:原地重试时单次最大等待时间
|
||||||
|
// 防止上游返回过长的 retryDelay 导致请求卡住太久
|
||||||
|
antigravitySingleAccountSmartRetryMaxWait = 15 * time.Second
|
||||||
|
|
||||||
|
// 单账号 503 退避重试:原地重试的总累计等待时间上限
|
||||||
|
// 超过此上限将不再重试,直接返回 503
|
||||||
|
antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second
|
||||||
|
|
||||||
|
// MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间
|
||||||
|
antigravityModelCapacityCooldown = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
|
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
|
||||||
@@ -54,8 +78,15 @@ var antigravityPassthroughErrorMessages = []string{
|
|||||||
"prompt is too long",
|
"prompt is too long",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MODEL_CAPACITY_EXHAUSTED 全局去重:避免多个并发请求同时对同一模型进行容量耗尽重试
|
||||||
|
var (
|
||||||
|
modelCapacityExhaustedMu sync.RWMutex
|
||||||
|
modelCapacityExhaustedUntil = make(map[string]time.Time) // modelName -> cooldown until
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
||||||
|
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
|
||||||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -117,6 +148,20 @@ type antigravityRetryLoopResult struct {
|
|||||||
resp *http.Response
|
resp *http.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveAntigravityForwardBaseURL 解析转发用 base URL。
|
||||||
|
// 默认使用 daily(ForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。
|
||||||
|
func resolveAntigravityForwardBaseURL() string {
|
||||||
|
baseURLs := antigravity.ForwardBaseURLs()
|
||||||
|
if len(baseURLs) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
|
||||||
|
if mode == "prod" && len(baseURLs) > 1 {
|
||||||
|
return baseURLs[1]
|
||||||
|
}
|
||||||
|
return baseURLs[0]
|
||||||
|
}
|
||||||
|
|
||||||
// smartRetryAction 智能重试的处理结果
|
// smartRetryAction 智能重试的处理结果
|
||||||
type smartRetryAction int
|
type smartRetryAction int
|
||||||
|
|
||||||
@@ -144,10 +189,17 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 判断是否触发智能重试
|
// 判断是否触发智能重试
|
||||||
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
||||||
|
|
||||||
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
||||||
if shouldRateLimitModel {
|
if shouldRateLimitModel {
|
||||||
|
// 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试
|
||||||
|
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||||
|
// 多账号场景下切换账号是最优选择,但单账号场景下设限流毫无意义(只会导致双重等待)。
|
||||||
|
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
|
||||||
|
return s.handleSingleAccountRetryInPlace(p, resp, respBody, baseURL, waitDuration, modelName)
|
||||||
|
}
|
||||||
|
|
||||||
rateLimitDuration := waitDuration
|
rateLimitDuration := waitDuration
|
||||||
if rateLimitDuration <= 0 {
|
if rateLimitDuration <= 0 {
|
||||||
rateLimitDuration = antigravityDefaultRateLimitDuration
|
rateLimitDuration = antigravityDefaultRateLimitDuration
|
||||||
@@ -174,20 +226,48 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次)
|
// 情况2: retryDelay < 阈值(或 MODEL_CAPACITY_EXHAUSTED),智能重试
|
||||||
if shouldSmartRetry {
|
if shouldSmartRetry {
|
||||||
var lastRetryResp *http.Response
|
var lastRetryResp *http.Response
|
||||||
var lastRetryBody []byte
|
var lastRetryBody []byte
|
||||||
|
|
||||||
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ {
|
// MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(60 次,固定 1s 间隔)
|
||||||
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
maxAttempts := antigravitySmartRetryMaxAttempts
|
||||||
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID)
|
if isModelCapacityExhausted {
|
||||||
|
maxAttempts = antigravityModelCapacityRetryMaxAttempts
|
||||||
|
waitDuration = antigravityModelCapacityRetryWait
|
||||||
|
|
||||||
|
// 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503
|
||||||
|
if modelName != "" {
|
||||||
|
modelCapacityExhaustedMu.RLock()
|
||||||
|
cooldownUntil, exists := modelCapacityExhaustedUntil[modelName]
|
||||||
|
modelCapacityExhaustedMu.RUnlock()
|
||||||
|
if exists && time.Now().Before(cooldownUntil) {
|
||||||
|
log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)",
|
||||||
|
p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05"))
|
||||||
|
return &smartRetryResult{
|
||||||
|
action: smartRetryActionBreakWithResp,
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: resp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||||||
|
p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID)
|
||||||
|
|
||||||
|
timer := time.NewTimer(waitDuration)
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
|
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
|
||||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||||||
case <-time.After(waitDuration):
|
case <-timer.C:
|
||||||
}
|
}
|
||||||
|
|
||||||
// 智能重试:创建新请求
|
// 智能重试:创建新请求
|
||||||
@@ -207,13 +287,19 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
|||||||
|
|
||||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
|
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
|
||||||
|
// 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown
|
||||||
|
if isModelCapacityExhausted && modelName != "" {
|
||||||
|
modelCapacityExhaustedMu.Lock()
|
||||||
|
delete(modelCapacityExhaustedUntil, modelName)
|
||||||
|
modelCapacityExhaustedMu.Unlock()
|
||||||
|
}
|
||||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
|
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 网络错误时,继续重试
|
// 网络错误时,继续重试
|
||||||
if retryErr != nil || retryResp == nil {
|
if retryErr != nil || retryResp == nil {
|
||||||
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
|
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, maxAttempts, retryErr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,20 +309,20 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
|||||||
}
|
}
|
||||||
lastRetryResp = retryResp
|
lastRetryResp = retryResp
|
||||||
if retryResp != nil {
|
if retryResp != nil {
|
||||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||||
_ = retryResp.Body.Close()
|
_ = retryResp.Body.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析新的重试信息,用于下次重试的等待时间
|
// 解析新的重试信息,用于下次重试的等待时间(MODEL_CAPACITY_EXHAUSTED 使用固定循环,跳过)
|
||||||
if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil {
|
if !isModelCapacityExhausted && attempt < maxAttempts && lastRetryBody != nil {
|
||||||
newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
newShouldRetry, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
||||||
if newShouldRetry && newWaitDuration > 0 {
|
if newShouldRetry && newWaitDuration > 0 {
|
||||||
waitDuration = newWaitDuration
|
waitDuration = newWaitDuration
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 所有重试都失败,限流当前模型并切换账号
|
// 所有重试都失败
|
||||||
rateLimitDuration := waitDuration
|
rateLimitDuration := waitDuration
|
||||||
if rateLimitDuration <= 0 {
|
if rateLimitDuration <= 0 {
|
||||||
rateLimitDuration = antigravityDefaultRateLimitDuration
|
rateLimitDuration = antigravityDefaultRateLimitDuration
|
||||||
@@ -245,8 +331,45 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
|||||||
if retryBody == nil {
|
if retryBody == nil {
|
||||||
retryBody = respBody
|
retryBody = respBody
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MODEL_CAPACITY_EXHAUSTED:模型容量不足,切换账号无意义
|
||||||
|
// 直接返回上游错误响应,不设置模型限流,不切换账号
|
||||||
|
if isModelCapacityExhausted {
|
||||||
|
// 设置 cooldown,让后续请求快速失败,避免重复重试
|
||||||
|
if modelName != "" {
|
||||||
|
modelCapacityExhaustedMu.Lock()
|
||||||
|
modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown)
|
||||||
|
modelCapacityExhaustedMu.Unlock()
|
||||||
|
}
|
||||||
|
log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)",
|
||||||
|
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
|
||||||
|
return &smartRetryResult{
|
||||||
|
action: smartRetryActionBreakWithResp,
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: resp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号,
|
||||||
|
// 直接返回 503 让 Handler 层的单账号退避循环做最终处理。
|
||||||
|
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
|
||||||
|
log.Printf("%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)",
|
||||||
|
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
|
||||||
|
return &smartRetryResult{
|
||||||
|
action: smartRetryActionBreakWithResp,
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: resp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
|
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
|
||||||
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
|
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
|
||||||
|
|
||||||
resetAt := time.Now().Add(rateLimitDuration)
|
resetAt := time.Now().Add(rateLimitDuration)
|
||||||
if p.accountRepo != nil && modelName != "" {
|
if p.accountRepo != nil && modelName != "" {
|
||||||
@@ -279,25 +402,163 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
|||||||
return &smartRetryResult{action: smartRetryActionContinue}
|
return &smartRetryResult{action: smartRetryActionContinue}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleSingleAccountRetryInPlace 单账号 503 退避重试的原地重试逻辑。
|
||||||
|
//
|
||||||
|
// 在多账号场景下,收到 503 + 长 retryDelay(≥ 7s)时会设置模型限流 + 切换账号;
|
||||||
|
// 但在单账号场景下,设限流毫无意义(因为切换回来的还是同一个账号,还要等限流过期)。
|
||||||
|
// 此方法改为在 Service 层原地等待 + 重试,避免双重等待问题:
|
||||||
|
//
|
||||||
|
// 旧流程:Service 设限流 → Handler 退避等待 → Service 等限流过期 → 再请求(总耗时 = 退避 + 限流)
|
||||||
|
// 新流程:Service 直接等 retryDelay → 重试 → 成功/再等 → 重试...(总耗时 ≈ 实际 retryDelay × 重试次数)
|
||||||
|
//
|
||||||
|
// 约束:
|
||||||
|
// - 单次等待不超过 antigravitySingleAccountSmartRetryMaxWait
|
||||||
|
// - 总累计等待不超过 antigravitySingleAccountSmartRetryTotalMaxWait
|
||||||
|
// - 最多重试 antigravitySingleAccountSmartRetryMaxAttempts 次
|
||||||
|
func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||||
|
p antigravityRetryLoopParams,
|
||||||
|
resp *http.Response,
|
||||||
|
respBody []byte,
|
||||||
|
baseURL string,
|
||||||
|
waitDuration time.Duration,
|
||||||
|
modelName string,
|
||||||
|
) *smartRetryResult {
|
||||||
|
// 限制单次等待时间
|
||||||
|
if waitDuration > antigravitySingleAccountSmartRetryMaxWait {
|
||||||
|
waitDuration = antigravitySingleAccountSmartRetryMaxWait
|
||||||
|
}
|
||||||
|
if waitDuration < antigravitySmartRetryMinWait {
|
||||||
|
waitDuration = antigravitySmartRetryMinWait
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)",
|
||||||
|
p.prefix, resp.StatusCode, modelName, p.account.ID, waitDuration)
|
||||||
|
|
||||||
|
var lastRetryResp *http.Response
|
||||||
|
var lastRetryBody []byte
|
||||||
|
totalWaited := time.Duration(0)
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= antigravitySingleAccountSmartRetryMaxAttempts; attempt++ {
|
||||||
|
// 检查累计等待是否超限
|
||||||
|
if totalWaited+waitDuration > antigravitySingleAccountSmartRetryTotalMaxWait {
|
||||||
|
remaining := antigravitySingleAccountSmartRetryTotalMaxWait - totalWaited
|
||||||
|
if remaining <= 0 {
|
||||||
|
log.Printf("%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up",
|
||||||
|
p.prefix, totalWaited, antigravitySingleAccountSmartRetryTotalMaxWait)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
waitDuration = remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
|
||||||
|
p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID)
|
||||||
|
|
||||||
|
timer := time.NewTimer(waitDuration)
|
||||||
|
select {
|
||||||
|
case <-p.ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix)
|
||||||
|
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
totalWaited += waitDuration
|
||||||
|
|
||||||
|
// 创建新请求
|
||||||
|
retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||||
|
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
log.Printf("%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v",
|
||||||
|
p.prefix, retryResp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited)
|
||||||
|
// 关闭之前的响应
|
||||||
|
if lastRetryResp != nil {
|
||||||
|
_ = lastRetryResp.Body.Close()
|
||||||
|
}
|
||||||
|
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 网络错误时继续重试
|
||||||
|
if retryErr != nil || retryResp == nil {
|
||||||
|
log.Printf("%s single_account_503_retry: network_error attempt=%d/%d error=%v",
|
||||||
|
p.prefix, attempt, antigravitySingleAccountSmartRetryMaxAttempts, retryErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭之前的响应
|
||||||
|
if lastRetryResp != nil {
|
||||||
|
_ = lastRetryResp.Body.Close()
|
||||||
|
}
|
||||||
|
lastRetryResp = retryResp
|
||||||
|
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||||
|
_ = retryResp.Body.Close()
|
||||||
|
|
||||||
|
// 解析新的重试信息,更新下次等待时间
|
||||||
|
if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil {
|
||||||
|
_, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
||||||
|
if newWaitDuration > 0 {
|
||||||
|
waitDuration = newWaitDuration
|
||||||
|
if waitDuration > antigravitySingleAccountSmartRetryMaxWait {
|
||||||
|
waitDuration = antigravitySingleAccountSmartRetryMaxWait
|
||||||
|
}
|
||||||
|
if waitDuration < antigravitySmartRetryMinWait {
|
||||||
|
waitDuration = antigravitySmartRetryMinWait
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有重试都失败,不设限流,直接返回 503
|
||||||
|
// Handler 层的单账号退避循环会做最终处理
|
||||||
|
retryBody := lastRetryBody
|
||||||
|
if retryBody == nil {
|
||||||
|
retryBody = respBody
|
||||||
|
}
|
||||||
|
log.Printf("%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)",
|
||||||
|
p.prefix, resp.StatusCode, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited, modelName, p.account.ID, truncateForLog(retryBody, 200))
|
||||||
|
|
||||||
|
return &smartRetryResult{
|
||||||
|
action: smartRetryActionBreakWithResp,
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: resp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||||
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||||
// 预检查:如果账号已限流,直接返回切换信号
|
// 预检查:如果账号已限流,直接返回切换信号
|
||||||
if p.requestedModel != "" {
|
if p.requestedModel != "" {
|
||||||
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
|
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
|
||||||
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
|
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
|
||||||
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
|
||||||
return nil, &AntigravityAccountSwitchError{
|
// 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace
|
||||||
OriginalAccountID: p.account.ID,
|
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
|
||||||
RateLimitedModel: p.requestedModel,
|
if isSingleAccountRetry(p.ctx) {
|
||||||
IsStickySession: p.isStickySession,
|
log.Printf("%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
|
||||||
|
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||||
|
} else {
|
||||||
|
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
|
||||||
|
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
|
||||||
|
return nil, &AntigravityAccountSwitchError{
|
||||||
|
OriginalAccountID: p.account.ID,
|
||||||
|
RateLimitedModel: p.requestedModel,
|
||||||
|
IsStickySession: p.isStickySession,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
baseURL := resolveAntigravityForwardBaseURL()
|
||||||
if len(availableURLs) == 0 {
|
if baseURL == "" {
|
||||||
availableURLs = antigravity.BaseURLs
|
return nil, errors.New("no antigravity forward base url configured")
|
||||||
}
|
}
|
||||||
|
availableURLs := []string{baseURL}
|
||||||
|
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
var usedBaseURL string
|
var usedBaseURL string
|
||||||
@@ -371,12 +632,12 @@ urlFallbackLoop:
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
// ★ 统一入口:自定义错误码 + 临时不可调度
|
// ★ 统一入口:自定义错误码 + 临时不可调度
|
||||||
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||||||
if policyErr != nil {
|
if policyErr != nil {
|
||||||
return nil, policyErr
|
return nil, policyErr
|
||||||
}
|
}
|
||||||
resp = &http.Response{
|
resp = &http.Response{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: outStatus,
|
||||||
Header: resp.Header.Clone(),
|
Header: resp.Header.Clone(),
|
||||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
}
|
}
|
||||||
@@ -610,21 +871,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun
|
|||||||
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
|
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
|
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。
|
||||||
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) {
|
// ErrorPolicySkipped 时 outStatus 为 500(前端约定:未命中的错误返回 500)。
|
||||||
|
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) {
|
||||||
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
|
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
|
||||||
case ErrorPolicySkipped:
|
case ErrorPolicySkipped:
|
||||||
return true, nil
|
return true, http.StatusInternalServerError, nil
|
||||||
case ErrorPolicyMatched:
|
case ErrorPolicyMatched:
|
||||||
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
|
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
|
||||||
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||||||
return true, nil
|
return true, statusCode, nil
|
||||||
case ErrorPolicyTempUnscheduled:
|
case ErrorPolicyTempUnscheduled:
|
||||||
slog.Info("temp_unschedulable_matched",
|
slog.Info("temp_unschedulable_matched",
|
||||||
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
||||||
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, statusCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapAntigravityModel 获取映射后的模型名
|
// mapAntigravityModel 获取映射后的模型名
|
||||||
@@ -734,11 +996,11 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
// URL fallback 循环
|
baseURL := resolveAntigravityForwardBaseURL()
|
||||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
if baseURL == "" {
|
||||||
if len(availableURLs) == 0 {
|
return nil, errors.New("no antigravity forward base url configured")
|
||||||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
|
||||||
}
|
}
|
||||||
|
availableURLs := []string{baseURL}
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for urlIdx, baseURL := range availableURLs {
|
for urlIdx, baseURL := range availableURLs {
|
||||||
@@ -1047,7 +1309,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
||||||
}
|
}
|
||||||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||||||
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||||||
|
|
||||||
// 获取 access_token
|
// 获取 access_token
|
||||||
@@ -1203,7 +1465,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||||
_ = retryResp.Body.Close()
|
_ = retryResp.Body.Close()
|
||||||
if retryResp.StatusCode == http.StatusTooManyRequests {
|
if retryResp.StatusCode == http.StatusTooManyRequests {
|
||||||
retryBaseURL := ""
|
retryBaseURL := ""
|
||||||
@@ -1284,6 +1546,27 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
|
||||||
|
|
||||||
|
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||||
|
if isGoogleProjectConfigError(msg) {
|
||||||
|
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||||
|
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||||||
|
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
@@ -1824,6 +2107,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
// Always record upstream context for Ops error logs, even when we will failover.
|
// Always record upstream context for Ops error logs, even when we will failover.
|
||||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
|
||||||
|
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
|
||||||
|
if resp.StatusCode == http.StatusBadRequest && isGoogleProjectConfigError(strings.ToLower(upstreamMsg)) {
|
||||||
|
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: requestID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps, RetryableOnSameAccount: true}
|
||||||
|
}
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
@@ -1919,6 +2218,44 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isGoogleProjectConfigError 判断(已提取的小写)错误消息是否属于 Google 服务端配置类问题。
|
||||||
|
// 只精确匹配已知的服务端侧错误,避免对客户端请求错误做无意义重试。
|
||||||
|
// 适用于所有走 Google 后端的平台(Antigravity、Gemini)。
|
||||||
|
func isGoogleProjectConfigError(lowerMsg string) bool {
|
||||||
|
// Google 间歇性 Bug:Project ID 有效但被临时识别失败
|
||||||
|
return strings.Contains(lowerMsg, "invalid project resource name")
|
||||||
|
}
|
||||||
|
|
||||||
|
// googleConfigErrorCooldown 服务端配置类 400 错误的临时封禁时长
|
||||||
|
const googleConfigErrorCooldown = 1 * time.Minute
|
||||||
|
|
||||||
|
// tempUnscheduleGoogleConfigError 对服务端配置类 400 错误触发临时封禁,
|
||||||
|
// 避免短时间内反复调度到同一个有问题的账号。
|
||||||
|
func tempUnscheduleGoogleConfigError(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
|
||||||
|
until := time.Now().Add(googleConfigErrorCooldown)
|
||||||
|
reason := "400: invalid project resource name (auto temp-unschedule 1m)"
|
||||||
|
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
|
||||||
|
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
|
||||||
|
} else {
|
||||||
|
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// emptyResponseCooldown 空流式响应的临时封禁时长
|
||||||
|
const emptyResponseCooldown = 1 * time.Minute
|
||||||
|
|
||||||
|
// tempUnscheduleEmptyResponse 对空流式响应触发临时封禁,
|
||||||
|
// 避免短时间内反复调度到同一个返回空响应的账号。
|
||||||
|
func tempUnscheduleEmptyResponse(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
|
||||||
|
until := time.Now().Add(emptyResponseCooldown)
|
||||||
|
reason := "empty stream response (auto temp-unschedule 1m)"
|
||||||
|
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
|
||||||
|
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
|
||||||
|
} else {
|
||||||
|
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
|
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
|
||||||
// 返回 true 表示正常完成等待,false 表示 context 已取消
|
// 返回 true 表示正常完成等待,false 表示 context 已取消
|
||||||
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||||
@@ -1935,14 +2272,22 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
|||||||
sleepFor = 0
|
sleepFor = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timer := time.NewTimer(sleepFor)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
return false
|
return false
|
||||||
case <-time.After(sleepFor):
|
case <-timer.C:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
|
||||||
|
func isSingleAccountRetry(ctx context.Context) bool {
|
||||||
|
v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool)
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
// setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流
|
// setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流
|
||||||
// 直接使用上游返回的模型 ID(如 claude-sonnet-4-5)作为限流 key
|
// 直接使用上游返回的模型 ID(如 claude-sonnet-4-5)作为限流 key
|
||||||
// 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false)
|
// 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false)
|
||||||
@@ -1977,8 +2322,9 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
|
|||||||
|
|
||||||
// antigravitySmartRetryInfo 智能重试所需的信息
|
// antigravitySmartRetryInfo 智能重试所需的信息
|
||||||
type antigravitySmartRetryInfo struct {
|
type antigravitySmartRetryInfo struct {
|
||||||
RetryDelay time.Duration // 重试延迟时间
|
RetryDelay time.Duration // 重试延迟时间
|
||||||
ModelName string // 限流的模型名称(如 "claude-sonnet-4-5")
|
ModelName string // 限流的模型名称(如 "claude-sonnet-4-5")
|
||||||
|
IsModelCapacityExhausted bool // 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
|
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
|
||||||
@@ -2093,31 +2439,40 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &antigravitySmartRetryInfo{
|
return &antigravitySmartRetryInfo{
|
||||||
RetryDelay: retryDelay,
|
RetryDelay: retryDelay,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
|
IsModelCapacityExhausted: hasModelCapacityExhausted,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
|
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
|
||||||
// 返回:
|
// 返回:
|
||||||
// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold)
|
// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold,或 MODEL_CAPACITY_EXHAUSTED)
|
||||||
// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold)
|
// - shouldRateLimitModel: 是否应该限流模型并切换账号(仅 RATE_LIMIT_EXCEEDED 且 retryDelay >= 阈值)
|
||||||
// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0)
|
// - waitDuration: 等待时间
|
||||||
// - modelName: 限流的模型名称
|
// - modelName: 限流的模型名称
|
||||||
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) {
|
// - isModelCapacityExhausted: 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED)
|
||||||
|
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string, isModelCapacityExhausted bool) {
|
||||||
if account.Platform != PlatformAntigravity {
|
if account.Platform != PlatformAntigravity {
|
||||||
return false, false, 0, ""
|
return false, false, 0, "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
info := parseAntigravitySmartRetryInfo(respBody)
|
info := parseAntigravitySmartRetryInfo(respBody)
|
||||||
if info == nil {
|
if info == nil {
|
||||||
return false, false, 0, ""
|
return false, false, 0, "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MODEL_CAPACITY_EXHAUSTED(模型容量不足):所有账号共享同一模型容量池
|
||||||
|
// 切换账号无意义,使用固定 1s 间隔重试
|
||||||
|
if info.IsModelCapacityExhausted {
|
||||||
|
return true, false, antigravityModelCapacityRetryWait, info.ModelName, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RATE_LIMIT_EXCEEDED(账号级限流):
|
||||||
// retryDelay >= 阈值:直接限流模型,不重试
|
// retryDelay >= 阈值:直接限流模型,不重试
|
||||||
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s
|
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s
|
||||||
if info.RetryDelay >= antigravityRateLimitThreshold {
|
if info.RetryDelay >= antigravityRateLimitThreshold {
|
||||||
return false, true, info.RetryDelay, info.ModelName
|
return false, true, info.RetryDelay, info.ModelName, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// retryDelay < 阈值:智能重试
|
// retryDelay < 阈值:智能重试
|
||||||
@@ -2126,7 +2481,7 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou
|
|||||||
waitDuration = antigravitySmartRetryMinWait
|
waitDuration = antigravitySmartRetryMinWait
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, false, waitDuration, info.ModelName
|
return true, false, waitDuration, info.ModelName, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleModelRateLimitParams 模型级限流处理参数
|
// handleModelRateLimitParams 模型级限流处理参数
|
||||||
@@ -2152,8 +2507,9 @@ type handleModelRateLimitResult struct {
|
|||||||
|
|
||||||
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
|
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
|
||||||
// 仅处理 429/503,解析模型名和 retryDelay
|
// 仅处理 429/503,解析模型名和 retryDelay
|
||||||
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试
|
// - MODEL_CAPACITY_EXHAUSTED: 返回 Handled=true(实际重试由 handleSmartRetry 处理)
|
||||||
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
|
// - RATE_LIMIT_EXCEEDED + retryDelay < 阈值: 返回 ShouldRetry=true,由调用方等待后重试
|
||||||
|
// - RATE_LIMIT_EXCEEDED + retryDelay >= 阈值: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
|
||||||
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
|
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
|
||||||
if p.statusCode != 429 && p.statusCode != 503 {
|
if p.statusCode != 429 && p.statusCode != 503 {
|
||||||
return &handleModelRateLimitResult{Handled: false}
|
return &handleModelRateLimitResult{Handled: false}
|
||||||
@@ -2164,7 +2520,17 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
|
|||||||
return &handleModelRateLimitResult{Handled: false}
|
return &handleModelRateLimitResult{Handled: false}
|
||||||
}
|
}
|
||||||
|
|
||||||
// < antigravityRateLimitThreshold: 等待后重试
|
// MODEL_CAPACITY_EXHAUSTED:模型容量不足,所有账号共享同一容量池
|
||||||
|
// 切换账号无意义,不设置模型限流(实际重试由 handleSmartRetry 处理)
|
||||||
|
if info.IsModelCapacityExhausted {
|
||||||
|
log.Printf("%s status=%d model_capacity_exhausted model=%s (not switching account, retry handled by smart retry)",
|
||||||
|
p.prefix, p.statusCode, info.ModelName)
|
||||||
|
return &handleModelRateLimitResult{
|
||||||
|
Handled: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试
|
||||||
if info.RetryDelay < antigravityRateLimitThreshold {
|
if info.RetryDelay < antigravityRateLimitThreshold {
|
||||||
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
|
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
|
||||||
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
|
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
|
||||||
@@ -2175,7 +2541,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
|
// RATE_LIMIT_EXCEEDED: >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
|
||||||
s.setModelRateLimitAndClearSession(p, info)
|
s.setModelRateLimitAndClearSession(p, info)
|
||||||
|
|
||||||
return &handleModelRateLimitResult{
|
return &handleModelRateLimitResult{
|
||||||
@@ -2242,6 +2608,10 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
|||||||
requestedModel string,
|
requestedModel string,
|
||||||
groupID int64, sessionHash string, isStickySession bool,
|
groupID int64, sessionHash string, isStickySession bool,
|
||||||
) *handleModelRateLimitResult {
|
) *handleModelRateLimitResult {
|
||||||
|
// 遵守自定义错误码策略:未命中则跳过所有限流处理
|
||||||
|
if !account.ShouldHandleErrorCode(statusCode) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
// 模型级限流处理(优先)
|
// 模型级限流处理(优先)
|
||||||
result := s.handleModelRateLimit(&handleModelRateLimitParams{
|
result := s.handleModelRateLimit(&handleModelRateLimitParams{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -2719,9 +3089,14 @@ returnResponse:
|
|||||||
// 选择最后一个有效响应
|
// 选择最后一个有效响应
|
||||||
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
||||||
|
|
||||||
// 处理空响应情况
|
// 处理空响应情况 — 触发同账号重试 + failover 切换账号
|
||||||
if last == nil && lastWithParts == nil {
|
if last == nil && lastWithParts == nil {
|
||||||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
log.Printf("[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover")
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
|
||||||
|
RetryableOnSameAccount: true,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果收集到了图片 parts,需要合并到最终响应中
|
// 如果收集到了图片 parts,需要合并到最终响应中
|
||||||
@@ -2939,6 +3314,21 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
|||||||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查错误透传规则
|
||||||
|
if ptStatus, ptErrType, ptErrMsg, matched := applyErrorPassthroughRule(
|
||||||
|
c, account.Platform, upstreamStatus, body,
|
||||||
|
0, "", "",
|
||||||
|
); matched {
|
||||||
|
c.JSON(ptStatus, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{"type": ptErrType, "message": ptErrMsg},
|
||||||
|
})
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
var statusCode int
|
var statusCode int
|
||||||
var errType, errMsg string
|
var errType, errMsg string
|
||||||
|
|
||||||
@@ -3134,10 +3524,14 @@ returnResponse:
|
|||||||
// 选择最后一个有效响应
|
// 选择最后一个有效响应
|
||||||
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
||||||
|
|
||||||
// 处理空响应情况
|
// 处理空响应情况 — 触发同账号重试 + failover 切换账号
|
||||||
if last == nil && lastWithParts == nil {
|
if last == nil && lastWithParts == nil {
|
||||||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
log.Printf("[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover")
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
|
||||||
|
RetryableOnSameAccount: true,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 将收集的所有 parts 合并到最终响应中
|
// 将收集的所有 parts 合并到最终响应中
|
||||||
@@ -3717,6 +4111,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
|
|||||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||||
usage.CacheCreationInputTokens = int(v)
|
usage.CacheCreationInputTokens = int(v)
|
||||||
}
|
}
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
if cc, ok := u["cache_creation"].(map[string]any); ok {
|
||||||
|
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreation5mTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreation1hTokens = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||||
@@ -3739,6 +4142,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
|
|||||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||||
usage.CacheCreationInputTokens = int(v)
|
usage.CacheCreationInputTokens = int(v)
|
||||||
}
|
}
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
if cc, ok := u["cache_creation"].(map[string]any); ok {
|
||||||
|
if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreation5mTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
|
||||||
|
usage.CacheCreation1hTokens = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return usage
|
return usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -553,6 +553,75 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
|||||||
require.NotContains(t, body, "event: error")
|
require.NotContains(t, body, "event: error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestHandleGeminiStreamingResponse_ThoughtsTokenCount
|
||||||
|
// 验证:Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens
|
||||||
|
func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
// promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90
|
||||||
|
require.Equal(t, 90, result.usage.InputTokens)
|
||||||
|
// candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110
|
||||||
|
require.Equal(t, 110, result.usage.OutputTokens)
|
||||||
|
require.Equal(t, 10, result.usage.CacheReadInputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleClaudeStreamingResponse_ThoughtsTokenCount
|
||||||
|
// 验证:Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens
|
||||||
|
func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
// promptTokenCount=50 → InputTokens=50
|
||||||
|
require.Equal(t, 50, result.usage.InputTokens)
|
||||||
|
// candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35
|
||||||
|
require.Equal(t, 35, result.usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
// --- 流式客户端断开检测测试 ---
|
// --- 流式客户端断开检测测试 ---
|
||||||
|
|
||||||
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
||||||
|
|||||||
@@ -192,6 +192,43 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
|
|||||||
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
|
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id)
|
||||||
|
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) {
|
||||||
|
var proxyURL string
|
||||||
|
if proxyID != nil {
|
||||||
|
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||||
|
if err == nil && proxy != nil {
|
||||||
|
proxyURL = proxy.URL()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新 token
|
||||||
|
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户信息(email)
|
||||||
|
client := antigravity.NewClient(proxyURL)
|
||||||
|
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
|
||||||
|
} else {
|
||||||
|
tokenInfo.Email = userInfo.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 project_id(容错,失败不阻塞)
|
||||||
|
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
|
||||||
|
if loadErr != nil {
|
||||||
|
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
|
||||||
|
tokenInfo.ProjectIDMissing = true
|
||||||
|
} else {
|
||||||
|
tokenInfo.ProjectID = projectID
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
func isNonRetryableAntigravityOAuthError(err error) bool {
|
func isNonRetryableAntigravityOAuthError(err error) bool {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
nonRetryable := []string{
|
nonRetryable := []string{
|
||||||
@@ -273,12 +310,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := antigravity.NewClient(proxyURL)
|
client := antigravity.NewClient(proxyURL)
|
||||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||||
|
|
||||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||||
return loadResp.CloudAICompanionProject, nil
|
return loadResp.CloudAICompanionProject, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
|
||||||
|
return projectID, nil
|
||||||
|
} else if onboardErr != nil {
|
||||||
|
lastErr = onboardErr
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 记录错误
|
// 记录错误
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
@@ -292,6 +338,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
|||||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
|
||||||
|
tierID := resolveDefaultTierID(loadRaw)
|
||||||
|
if tierID == "" {
|
||||||
|
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
|
||||||
|
}
|
||||||
|
|
||||||
|
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
|
||||||
|
}
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveDefaultTierID(loadRaw map[string]any) string {
|
||||||
|
if len(loadRaw) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
rawTiers, ok := loadRaw["allowedTiers"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
tiers, ok := rawTiers.([]any)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rawTier := range tiers {
|
||||||
|
tier, ok := rawTier.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id, ok := tier["id"].(string); ok {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// FillProjectID 仅获取 project_id,不刷新 OAuth token
|
||||||
|
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||||
|
var proxyURL string
|
||||||
|
if account.ProxyID != nil {
|
||||||
|
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||||
|
if err == nil && proxy != nil {
|
||||||
|
proxyURL = proxy.URL()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
|
||||||
|
}
|
||||||
|
|
||||||
// BuildAccountCredentials 构建账户凭证
|
// BuildAccountCredentials 构建账户凭证
|
||||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||||
creds := map[string]any{
|
creds := map[string]any{
|
||||||
|
|||||||
82
backend/internal/service/antigravity_oauth_service_test.go
Normal file
82
backend/internal/service/antigravity_oauth_service_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveDefaultTierID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
loadRaw map[string]any
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil loadRaw",
|
||||||
|
loadRaw: nil,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing allowedTiers",
|
||||||
|
loadRaw: map[string]any{
|
||||||
|
"paidTier": map[string]any{"id": "g1-pro-tier"},
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty allowedTiers",
|
||||||
|
loadRaw: map[string]any{"allowedTiers": []any{}},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tier missing id field",
|
||||||
|
loadRaw: map[string]any{
|
||||||
|
"allowedTiers": []any{
|
||||||
|
map[string]any{"isDefault": true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allowedTiers but no default",
|
||||||
|
loadRaw: map[string]any{
|
||||||
|
"allowedTiers": []any{
|
||||||
|
map[string]any{"id": "free-tier", "isDefault": false},
|
||||||
|
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default tier found",
|
||||||
|
loadRaw: map[string]any{
|
||||||
|
"allowedTiers": []any{
|
||||||
|
map[string]any{"id": "free-tier", "isDefault": true},
|
||||||
|
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "free-tier",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default tier id with spaces",
|
||||||
|
loadRaw: map[string]any{
|
||||||
|
"allowedTiers": []any{
|
||||||
|
map[string]any{"id": " standard-tier ", "isDefault": true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "standard-tier",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := resolveDefaultTierID(tc.loadRaw)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -86,7 +86,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
|
||||||
|
t.Setenv(antigravityForwardBaseURLEnv, "")
|
||||||
|
|
||||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||||
oldAvailability := antigravity.DefaultURLAvailability
|
oldAvailability := antigravity.DefaultURLAvailability
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -131,15 +133,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.resp)
|
require.NotNil(t, result.resp)
|
||||||
defer func() { _ = result.resp.Body.Close() }()
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
|
||||||
require.False(t, handleErrorCalled)
|
require.True(t, handleErrorCalled)
|
||||||
require.Len(t, upstream.calls, 2)
|
require.Len(t, upstream.calls, antigravityMaxRetries)
|
||||||
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
|
for _, callURL := range upstream.calls {
|
||||||
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
|
require.True(t, strings.HasPrefix(callURL, base1))
|
||||||
|
}
|
||||||
|
|
||||||
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||||
require.NotEmpty(t, available)
|
require.NotEmpty(t, available)
|
||||||
require.Equal(t, base2, available[0])
|
require.Equal(t, base1, available[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
|
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
|
||||||
@@ -188,13 +191,14 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
|||||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
|
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
||||||
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
|
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
||||||
|
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
||||||
repo := &stubAntigravityAccountRepo{}
|
repo := &stubAntigravityAccountRepo{}
|
||||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
|
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
|
||||||
|
|
||||||
// 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流
|
// 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号
|
||||||
body := []byte(`{
|
body := []byte(`{
|
||||||
"error": {
|
"error": {
|
||||||
"status": "UNAVAILABLE",
|
"status": "UNAVAILABLE",
|
||||||
@@ -207,13 +211,13 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
|
|||||||
|
|
||||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||||
|
|
||||||
// 应该触发模型限流
|
// MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流
|
||||||
|
// 实际重试由 handleSmartRetry 处理
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.Handled)
|
require.True(t, result.Handled)
|
||||||
require.NotNil(t, result.SwitchError)
|
require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path")
|
||||||
require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel)
|
require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch")
|
||||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
|
||||||
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
|
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
|
||||||
@@ -301,11 +305,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
|
|||||||
|
|
||||||
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
body string
|
body string
|
||||||
expectedDelay time.Duration
|
expectedDelay time.Duration
|
||||||
expectedModel string
|
expectedModel string
|
||||||
expectedNil bool
|
expectedNil bool
|
||||||
|
expectedIsModelCapacityExhausted bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid complete response with RATE_LIMIT_EXCEEDED",
|
name: "valid complete response with RATE_LIMIT_EXCEEDED",
|
||||||
@@ -368,8 +373,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
|||||||
"message": "No capacity available for model gemini-3-pro-high on the server"
|
"message": "No capacity available for model gemini-3-pro-high on the server"
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
expectedDelay: 39 * time.Second,
|
expectedDelay: 39 * time.Second,
|
||||||
expectedModel: "gemini-3-pro-high",
|
expectedModel: "gemini-3-pro-high",
|
||||||
|
expectedIsModelCapacityExhausted: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
|
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
|
||||||
@@ -480,6 +486,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
|||||||
if result.ModelName != tt.expectedModel {
|
if result.ModelName != tt.expectedModel {
|
||||||
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
|
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
|
||||||
}
|
}
|
||||||
|
if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
|
||||||
|
t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -491,13 +500,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
apiKeyAccount := &Account{Type: AccountTypeAPIKey}
|
apiKeyAccount := &Account{Type: AccountTypeAPIKey}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
account *Account
|
account *Account
|
||||||
body string
|
body string
|
||||||
expectedShouldRetry bool
|
expectedShouldRetry bool
|
||||||
expectedShouldRateLimit bool
|
expectedShouldRateLimit bool
|
||||||
minWait time.Duration
|
expectedIsModelCapacityExhausted bool
|
||||||
modelName string
|
minWait time.Duration
|
||||||
|
modelName string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "OAuth account with short delay (< 7s) - smart retry",
|
name: "OAuth account with short delay (< 7s) - smart retry",
|
||||||
@@ -611,13 +621,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
expectedShouldRetry: false,
|
expectedShouldRetry: true,
|
||||||
expectedShouldRateLimit: true,
|
expectedShouldRateLimit: false,
|
||||||
minWait: 39 * time.Second,
|
expectedIsModelCapacityExhausted: true,
|
||||||
modelName: "gemini-3-pro-high",
|
minWait: 1 * time.Second,
|
||||||
|
modelName: "gemini-3-pro-high",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit",
|
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait",
|
||||||
account: oauthAccount,
|
account: oauthAccount,
|
||||||
body: `{
|
body: `{
|
||||||
"error": {
|
"error": {
|
||||||
@@ -629,10 +640,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
"message": "No capacity available for model gemini-2.5-flash on the server"
|
"message": "No capacity available for model gemini-2.5-flash on the server"
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
expectedShouldRetry: false,
|
expectedShouldRetry: true,
|
||||||
expectedShouldRateLimit: true,
|
expectedShouldRateLimit: false,
|
||||||
minWait: 30 * time.Second,
|
expectedIsModelCapacityExhausted: true,
|
||||||
modelName: "gemini-2.5-flash",
|
minWait: 1 * time.Second,
|
||||||
|
modelName: "gemini-2.5-flash",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
|
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
|
||||||
@@ -656,13 +668,16 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
|
shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
|
||||||
if shouldRetry != tt.expectedShouldRetry {
|
if shouldRetry != tt.expectedShouldRetry {
|
||||||
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
|
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
|
||||||
}
|
}
|
||||||
if shouldRateLimit != tt.expectedShouldRateLimit {
|
if shouldRateLimit != tt.expectedShouldRateLimit {
|
||||||
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
|
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
|
||||||
}
|
}
|
||||||
|
if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
|
||||||
|
t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
|
||||||
|
}
|
||||||
if shouldRetry {
|
if shouldRetry {
|
||||||
if wait < tt.minWait {
|
if wait < tt.minWait {
|
||||||
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
|
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
|
||||||
@@ -915,6 +930,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) {
|
||||||
|
t.Setenv(antigravityForwardBaseURLEnv, "")
|
||||||
|
|
||||||
|
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||||
|
defer func() {
|
||||||
|
antigravity.BaseURLs = oldBaseURLs
|
||||||
|
}()
|
||||||
|
|
||||||
|
prodURL := "https://prod.test"
|
||||||
|
dailyURL := "https://daily.test"
|
||||||
|
antigravity.BaseURLs = []string{dailyURL, prodURL}
|
||||||
|
|
||||||
|
resolved := resolveAntigravityForwardBaseURL()
|
||||||
|
require.Equal(t, dailyURL, resolved)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
|
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
|
||||||
err := &AntigravityAccountSwitchError{
|
err := &AntigravityAccountSwitchError{
|
||||||
OriginalAccountID: 789,
|
OriginalAccountID: 789,
|
||||||
|
|||||||
@@ -0,0 +1,904 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 辅助函数:构造带 SingleAccountRetry 标记的 context
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func ctxWithSingleAccountRetry() context.Context {
|
||||||
|
return context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 1. isSingleAccountRetry 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestIsSingleAccountRetry_True(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
|
||||||
|
require.True(t, isSingleAccountRetry(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsSingleAccountRetry_False_NoValue(t *testing.T) {
|
||||||
|
require.False(t, isSingleAccountRetry(context.Background()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsSingleAccountRetry_False_ExplicitFalse(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, false)
|
||||||
|
require.False(t, isSingleAccountRetry(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsSingleAccountRetry_False_WrongType(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, "true")
|
||||||
|
require.False(t, isSingleAccountRetry(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 2. 常量验证
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSingleAccountRetryConstants(t *testing.T) {
|
||||||
|
require.Equal(t, 3, antigravitySingleAccountSmartRetryMaxAttempts,
|
||||||
|
"单账号原地重试最多 3 次")
|
||||||
|
require.Equal(t, 15*time.Second, antigravitySingleAccountSmartRetryMaxWait,
|
||||||
|
"单次最大等待 15s")
|
||||||
|
require.Equal(t, 30*time.Second, antigravitySingleAccountSmartRetryTotalMaxWait,
|
||||||
|
"总累计等待不超过 30s")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 3. handleSmartRetry + 503 + SingleAccountRetry → 走 handleSingleAccountRetryInPlace
|
||||||
|
// (而非设模型限流 + 切换账号)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace
|
||||||
|
// 核心场景:503 + retryDelay >= 7s + SingleAccountRetry 标记
|
||||||
|
// → 不设模型限流、不切换账号,改为原地重试
|
||||||
|
func TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace(t *testing.T) {
|
||||||
|
// 原地重试成功
|
||||||
|
successResp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{successResp},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "acc-single",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 503 + 39s >= 7s 阈值 + MODEL_CAPACITY_EXHAUSTED
|
||||||
|
respBody := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "UNAVAILABLE",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||||
|
],
|
||||||
|
"message": "No capacity available for model gemini-3-pro-high on the server"
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(), // 关键:设置单账号标记
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
accountRepo: repo,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
availableURLs := []string{"https://ag-1.test"}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
// 关键断言:返回 resp(原地重试成功),而非 switchError(切换账号)
|
||||||
|
require.NotNil(t, result.resp, "should return successful response from in-place retry")
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
|
require.Nil(t, result.switchError, "should NOT return switchError in single account mode")
|
||||||
|
require.Nil(t, result.err)
|
||||||
|
|
||||||
|
// 验证未设模型限流(单账号模式不应设限流)
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||||
|
"should NOT set model rate limit in single account retry mode")
|
||||||
|
|
||||||
|
// 验证确实调用了 upstream(原地重试)
|
||||||
|
require.GreaterOrEqual(t, len(upstream.calls), 1, "should have made at least one retry call")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches
|
||||||
|
// 对照组:503 + retryDelay >= 7s + 无 SingleAccountRetry 标记
|
||||||
|
// → 照常设模型限流 + 切换账号
|
||||||
|
func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 2,
|
||||||
|
Name: "acc-multi",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,
|
||||||
|
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel)
|
||||||
|
respBody := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(), // 关键:无单账号标记
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
accountRepo: repo,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
availableURLs := []string{"https://ag-1.test"}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
// 对照:多账号模式返回 switchError
|
||||||
|
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
|
||||||
|
require.Nil(t, result.resp, "should not return resp when switchError is set")
|
||||||
|
|
||||||
|
// 对照:多账号模式应设模型限流
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||||
|
"multi-account mode SHOULD set model rate limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches
|
||||||
|
// 边界情况:429(非 503)+ SingleAccountRetry 标记
|
||||||
|
// → 单账号原地重试仅针对 503,429 依然走切换账号逻辑
|
||||||
|
func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 3,
|
||||||
|
Name: "acc-429",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 429 + 15s >= 7s 阈值
|
||||||
|
respBody := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusTooManyRequests, // 429,不是 503
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(), // 有单账号标记
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
accountRepo: repo,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
availableURLs := []string{"https://ag-1.test"}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
// 429 即使有单账号标记,也应走切换账号
|
||||||
|
require.NotNil(t, result.switchError, "429 should still return switchError even with SingleAccountRetry")
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||||
|
"429 should still set model rate limit even with SingleAccountRetry")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 4. handleSmartRetry + 503 + 短延迟 + SingleAccountRetry → 智能重试耗尽后不设限流
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
|
||||||
|
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流
|
||||||
|
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
|
||||||
|
// 智能重试也返回 503
|
||||||
|
failRespBody := `{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "UNAVAILABLE",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
failResp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{failResp},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 4,
|
||||||
|
Name: "acc-short-503",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0.1s < 7s 阈值
|
||||||
|
respBody := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "UNAVAILABLE",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
accountRepo: repo,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
availableURLs := []string{"https://ag-1.test"}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
// 关键断言:单账号 503 模式下,智能重试耗尽后直接返回 503 响应,不切换
|
||||||
|
require.NotNil(t, result.resp, "should return 503 response directly for single account mode")
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
|
||||||
|
require.Nil(t, result.switchError, "should NOT switch account in single account mode")
|
||||||
|
|
||||||
|
// 关键断言:不设模型限流
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||||
|
"should NOT set model rate limit for 503 in single account mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
|
||||||
|
// 对照组:503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
|
||||||
|
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,因为后者走独立的 60 次重试路径
|
||||||
|
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
|
||||||
|
failRespBody := `{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
failResp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{failResp},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 5,
|
||||||
|
Name: "acc-multi-503",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(), // 无单账号标记
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
accountRepo: repo,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
availableURLs := []string{"https://ag-1.test"}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
// 对照:多账号模式应返回 switchError
|
||||||
|
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
|
||||||
|
// 对照:多账号模式应设模型限流
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||||
|
"multi-account mode should set model rate limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 5. handleSingleAccountRetryInPlace 直接测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestHandleSingleAccountRetryInPlace_Success 原地重试成功
|
||||||
|
func TestHandleSingleAccountRetryInPlace_Success(t *testing.T) {
|
||||||
|
successResp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{successResp},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 10,
|
||||||
|
Name: "acc-inplace-ok",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
require.NotNil(t, result.resp, "should return successful response")
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
|
require.Nil(t, result.switchError, "should not switch account on success")
|
||||||
|
require.Nil(t, result.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleSingleAccountRetryInPlace_AllRetriesFail 所有重试都失败,返回 503(不设限流)
|
||||||
|
func TestHandleSingleAccountRetryInPlace_AllRetriesFail(t *testing.T) {
|
||||||
|
// 构造 3 个 503 响应(对应 3 次原地重试)
|
||||||
|
var responses []*http.Response
|
||||||
|
var errors []error
|
||||||
|
for i := 0; i < antigravitySingleAccountSmartRetryMaxAttempts; i++ {
|
||||||
|
responses = append(responses, &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "UNAVAILABLE",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)),
|
||||||
|
})
|
||||||
|
errors = append(errors, nil)
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: responses,
|
||||||
|
errors: errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 11,
|
||||||
|
Name: "acc-inplace-fail",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
origBody := []byte(`{"error":{"code":503,"status":"UNAVAILABLE"}}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{"X-Test": {"original"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSingleAccountRetryInPlace(params, resp, origBody, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
// 关键:返回 503 resp,不返回 switchError
|
||||||
|
require.NotNil(t, result.resp, "should return 503 response directly")
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
|
||||||
|
require.Nil(t, result.switchError, "should NOT return switchError - let Handler handle it")
|
||||||
|
require.Nil(t, result.err)
|
||||||
|
|
||||||
|
// 验证确实重试了指定次数
|
||||||
|
require.Len(t, upstream.calls, antigravitySingleAccountSmartRetryMaxAttempts,
|
||||||
|
"should have made exactly maxAttempts retry calls")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleSingleAccountRetryInPlace_WaitDurationClamped 等待时间被限制在 [min, max] 范围
|
||||||
|
func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
|
||||||
|
// 用短延迟的成功响应,只验证不 panic
|
||||||
|
successResp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{successResp},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 12,
|
||||||
|
Name: "acc-clamp",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
|
||||||
|
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
|
||||||
|
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
require.NotNil(t, result.resp)
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleSingleAccountRetryInPlace_ContextCanceled context 取消时立即返回
|
||||||
|
func TestHandleSingleAccountRetryInPlace_ContextCanceled(t *testing.T) {
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{nil},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 13,
|
||||||
|
Name: "acc-cancel",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
|
||||||
|
cancel() // 立即取消
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctx,
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
require.Error(t, result.err, "should return context error")
|
||||||
|
// 不应调用 upstream(因为在等待阶段就被取消了)
|
||||||
|
require.Len(t, upstream.calls, 0, "should not call upstream when context is canceled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry 网络错误时继续重试
|
||||||
|
func TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry(t *testing.T) {
|
||||||
|
successResp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||||
|
}
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
// 第1次网络错误(nil resp),第2次成功
|
||||||
|
responses: []*http.Response{nil, successResp},
|
||||||
|
errors: []error{nil, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 14,
|
||||||
|
Name: "acc-net-retry",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
require.NotNil(t, result.resp, "should return successful response after network error recovery")
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
|
require.Len(t, upstream.calls, 2, "first call fails (network error), second succeeds")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 6. antigravityRetryLoop 预检查:单账号模式跳过限流
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit
|
||||||
|
// 预检查中,如果有 SingleAccountRetry 标记,即使账号已限流也跳过直接发请求
|
||||||
|
func TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit(t *testing.T) {
|
||||||
|
// 创建一个已设模型限流的账号
|
||||||
|
upstream := &recordingOKUpstream{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 20,
|
||||||
|
Name: "acc-rate-limited",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Extra: map[string]any{
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
requestedModel: "claude-sonnet-4-5",
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err, "should not return error")
|
||||||
|
require.NotNil(t, result, "should return result")
|
||||||
|
require.NotNil(t, result.resp, "should have response")
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
|
// 关键:尽管限流了,有 SingleAccountRetry 标记时仍然到达了 upstream
|
||||||
|
require.Equal(t, 1, upstream.calls, "should have reached upstream despite rate limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit
|
||||||
|
// 对照组:无 SingleAccountRetry + 已限流 → 预检查返回 switchError
|
||||||
|
func TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit(t *testing.T) {
|
||||||
|
upstream := &recordingOKUpstream{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 21,
|
||||||
|
Name: "acc-rate-limited-multi",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Extra: map[string]any{
|
||||||
|
modelRateLimitsKey: map[string]any{
|
||||||
|
"claude-sonnet-4-5": map[string]any{
|
||||||
|
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
ctx: context.Background(), // 无单账号标记
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
requestedModel: "claude-sonnet-4-5",
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Nil(t, result, "should not return result on rate limit switch")
|
||||||
|
require.NotNil(t, err, "should return error")
|
||||||
|
|
||||||
|
var switchErr *AntigravityAccountSwitchError
|
||||||
|
require.ErrorAs(t, err, &switchErr, "should return AntigravityAccountSwitchError")
|
||||||
|
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||||
|
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
|
||||||
|
|
||||||
|
// upstream 不应被调用(预检查就短路了)
|
||||||
|
require.Equal(t, 0, upstream.calls, "upstream should NOT be called when pre-check blocks")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 7. 端到端集成场景测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E
|
||||||
|
// 端到端场景:503 + 单账号 + 原地重试第2次成功
|
||||||
|
func TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E(t *testing.T) {
|
||||||
|
// 第1次原地重试仍返回 503,第2次成功
|
||||||
|
fail503Body := `{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "UNAVAILABLE",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
resp503 := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(fail503Body)),
|
||||||
|
}
|
||||||
|
successResp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||||
|
}
|
||||||
|
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{resp503, successResp},
|
||||||
|
errors: []error{nil, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 30,
|
||||||
|
Name: "acc-e2e",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
require.NotNil(t, result.resp, "should return successful response after 2nd attempt")
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
|
require.Nil(t, result.switchError)
|
||||||
|
require.Len(t, upstream.calls, 2, "first 503, second OK")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E
|
||||||
|
// 通过 antigravityRetryLoop → handleSmartRetry → handleSingleAccountRetryInPlace 完整链路
|
||||||
|
func TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E(t *testing.T) {
|
||||||
|
// 初始请求返回 503 + 长延迟
|
||||||
|
initial503Body := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "UNAVAILABLE",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "10s"}
|
||||||
|
],
|
||||||
|
"message": "No capacity available"
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
initial503Resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(initial503Body)),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原地重试成功
|
||||||
|
successResp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||||
|
}
|
||||||
|
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
// 第1次调用(retryLoop 主循环)返回 503
|
||||||
|
// 第2次调用(handleSingleAccountRetryInPlace 原地重试)返回 200
|
||||||
|
responses: []*http.Response{initial503Resp, successResp},
|
||||||
|
errors: []error{nil, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 31,
|
||||||
|
Name: "acc-e2e-loop",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
|
ctx: ctxWithSingleAccountRetry(),
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
accountRepo: repo,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err, "should not return error on successful retry")
|
||||||
|
require.NotNil(t, result, "should return result")
|
||||||
|
require.NotNil(t, result.resp, "should return response")
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
|
|
||||||
|
// 验证未设模型限流
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||||
|
"should NOT set model rate limit in single account retry mode")
|
||||||
|
}
|
||||||
@@ -294,8 +294,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
|
|||||||
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
|
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
|
// TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功
|
||||||
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) {
|
// MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号
|
||||||
|
func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) {
|
||||||
repo := &stubAntigravityAccountRepo{}
|
repo := &stubAntigravityAccountRepo{}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
ID: 3,
|
ID: 3,
|
||||||
@@ -304,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
|||||||
Platform: PlatformAntigravity,
|
Platform: PlatformAntigravity,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值
|
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s)
|
||||||
respBody := []byte(`{
|
respBody := []byte(`{
|
||||||
"error": {
|
"error": {
|
||||||
"code": 503,
|
"code": 503,
|
||||||
@@ -322,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
|||||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mock: 第 1 次重试返回 200 成功
|
||||||
|
upstream := &mockSmartRetryUpstream{
|
||||||
|
responses: []*http.Response{
|
||||||
|
{StatusCode: http.StatusOK, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))},
|
||||||
|
},
|
||||||
|
errors: []error{nil},
|
||||||
|
}
|
||||||
|
|
||||||
params := antigravityRetryLoopParams{
|
params := antigravityRetryLoopParams{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
prefix: "[test]",
|
prefix: "[test]",
|
||||||
@@ -330,6 +339,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
|||||||
action: "generateContent",
|
action: "generateContent",
|
||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
accountRepo: repo,
|
accountRepo: repo,
|
||||||
|
httpUpstream: upstream,
|
||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
return nil
|
return nil
|
||||||
@@ -343,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
|||||||
|
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
require.Nil(t, result.resp)
|
require.NotNil(t, result.resp, "should return successful response")
|
||||||
|
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||||
require.Nil(t, result.err)
|
require.Nil(t, result.err)
|
||||||
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted")
|
require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError")
|
||||||
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
|
|
||||||
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
|
|
||||||
require.True(t, result.switchError.IsStickySession)
|
|
||||||
|
|
||||||
// 验证模型限流已设置
|
// 不应设置模型限流
|
||||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
|
||||||
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
|
require.Len(t, upstream.calls, 1, "should have made one retry call before success")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消
|
||||||
|
func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
account := &Account{
|
||||||
|
ID: 3,
|
||||||
|
Name: "acc-3",
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := []byte(`{
|
||||||
|
"error": {
|
||||||
|
"code": 503,
|
||||||
|
"status": "UNAVAILABLE",
|
||||||
|
"details": [
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||||
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 立即取消上下文,验证重试循环能正确退出
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
params := antigravityRetryLoopParams{
|
||||||
|
ctx: ctx,
|
||||||
|
prefix: "[test]",
|
||||||
|
account: account,
|
||||||
|
accessToken: "token",
|
||||||
|
action: "generateContent",
|
||||||
|
body: []byte(`{"input":"test"}`),
|
||||||
|
accountRepo: repo,
|
||||||
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||||
|
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||||
|
require.Error(t, result.err, "should return context error")
|
||||||
|
require.Nil(t, result.switchError, "should not return switchError on context cancel")
|
||||||
|
require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
|
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
|
||||||
@@ -1129,20 +1190,20 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
|
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
|
||||||
// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
|
// 429 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
|
||||||
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
|
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
|
||||||
failRespBody := `{
|
failRespBody := `{
|
||||||
"error": {
|
"error": {
|
||||||
"code": 503,
|
"code": 429,
|
||||||
"status": "UNAVAILABLE",
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
"details": [
|
"details": [
|
||||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
failResp := &http.Response{
|
failResp := &http.Response{
|
||||||
StatusCode: http.StatusServiceUnavailable,
|
StatusCode: http.StatusTooManyRequests,
|
||||||
Header: http.Header{},
|
Header: http.Header{},
|
||||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||||
}
|
}
|
||||||
@@ -1162,16 +1223,16 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
|
|||||||
|
|
||||||
respBody := []byte(`{
|
respBody := []byte(`{
|
||||||
"error": {
|
"error": {
|
||||||
"code": 503,
|
"code": 429,
|
||||||
"status": "UNAVAILABLE",
|
"status": "RESOURCE_EXHAUSTED",
|
||||||
"details": [
|
"details": [
|
||||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}`)
|
}`)
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
StatusCode: http.StatusServiceUnavailable,
|
StatusCode: http.StatusTooManyRequests,
|
||||||
Header: http.Header{},
|
Header: http.Header{},
|
||||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,12 +7,14 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||||
antigravityTokenCacheSkew = 5 * time.Minute
|
antigravityTokenCacheSkew = 5 * time.Minute
|
||||||
|
antigravityBackfillCooldown = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||||
@@ -23,6 +25,7 @@ type AntigravityTokenProvider struct {
|
|||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache AntigravityTokenCache
|
tokenCache AntigravityTokenCache
|
||||||
antigravityOAuthService *AntigravityOAuthService
|
antigravityOAuthService *AntigravityOAuthService
|
||||||
|
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAntigravityTokenProvider(
|
func NewAntigravityTokenProvider(
|
||||||
@@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
p.mergeCredentials(account, tokenInfo)
|
||||||
for k, v := range account.Credentials {
|
|
||||||
if _, exists := newCredentials[k]; !exists {
|
|
||||||
newCredentials[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
account.Credentials = newCredentials
|
|
||||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||||
}
|
}
|
||||||
@@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
|
||||||
|
// "Invalid project resource name projects/"。
|
||||||
|
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
|
||||||
|
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||||
|
if p.shouldAttemptBackfill(account.ID) {
|
||||||
|
p.markBackfillAttempted(account.ID)
|
||||||
|
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||||
|
account.Credentials["project_id"] = projectID
|
||||||
|
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||||
|
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
@@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
|
||||||
|
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
|
||||||
|
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
if _, exists := newCredentials[k]; !exists {
|
||||||
|
newCredentials[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
account.Credentials = newCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
|
||||||
|
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
||||||
|
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
||||||
|
if lastAttempt, ok := v.(time.Time); ok {
|
||||||
|
return time.Since(lastAttempt) > antigravityBackfillCooldown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
|
||||||
|
p.backfillCooldown.Store(accountID, time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
func AntigravityTokenCacheKey(account *Account) string {
|
func AntigravityTokenCacheKey(account *Account) string {
|
||||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ type ModelPricing struct {
|
|||||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||||
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
|
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||||
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
|
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
|||||||
if s.pricingService != nil {
|
if s.pricingService != nil {
|
||||||
litellmPricing := s.pricingService.GetModelPricing(model)
|
litellmPricing := s.pricingService.GetModelPricing(model)
|
||||||
if litellmPricing != nil {
|
if litellmPricing != nil {
|
||||||
|
// 启用 5m/1h 分类计费的条件:
|
||||||
|
// 1. 存在 1h 价格
|
||||||
|
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
|
||||||
|
price5m := litellmPricing.CacheCreationInputTokenCost
|
||||||
|
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||||
|
enableBreakdown := price1h > 0 && price1h > price5m
|
||||||
return &ModelPricing{
|
return &ModelPricing{
|
||||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||||
SupportsCacheBreakdown: false,
|
CacheCreation5mPrice: price5m,
|
||||||
|
CacheCreation1hPrice: price1h,
|
||||||
|
SupportsCacheBreakdown: enableBreakdown,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
|
|||||||
|
|
||||||
// 计算缓存费用
|
// 计算缓存费用
|
||||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||||
// 支持详细缓存分类的模型(5分钟/1小时缓存)
|
// 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
|
||||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
|
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
|
||||||
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
|
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
|
||||||
|
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||||
|
} else {
|
||||||
|
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||||
|
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 标准缓存创建价格(per-token)
|
// 标准缓存创建价格(per-token)
|
||||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||||
@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
|||||||
|
|
||||||
// 范围内部分:正常计费
|
// 范围内部分:正常计费
|
||||||
inRangeTokens := UsageTokens{
|
inRangeTokens := UsageTokens{
|
||||||
InputTokens: inRangeInputTokens,
|
InputTokens: inRangeInputTokens,
|
||||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||||
CacheReadTokens: inRangeCacheTokens,
|
CacheReadTokens: inRangeCacheTokens,
|
||||||
|
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
|
||||||
}
|
}
|
||||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
112
backend/internal/service/crs_sync_helpers_test.go
Normal file
112
backend/internal/service/crs_sync_helpers_test.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildSelectedSet(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ids []string
|
||||||
|
wantNil bool
|
||||||
|
wantSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil input returns nil (backward compatible: create all)",
|
||||||
|
ids: nil,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty slice returns empty map (create none)",
|
||||||
|
ids: []string{},
|
||||||
|
wantNil: false,
|
||||||
|
wantSize: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single ID",
|
||||||
|
ids: []string{"abc-123"},
|
||||||
|
wantNil: false,
|
||||||
|
wantSize: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple IDs",
|
||||||
|
ids: []string{"a", "b", "c"},
|
||||||
|
wantNil: false,
|
||||||
|
wantSize: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "duplicate IDs are deduplicated",
|
||||||
|
ids: []string{"a", "a", "b"},
|
||||||
|
wantNil: false,
|
||||||
|
wantSize: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := buildSelectedSet(tt.ids)
|
||||||
|
if tt.wantNil {
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids)
|
||||||
|
}
|
||||||
|
if len(got) != tt.wantSize {
|
||||||
|
t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize)
|
||||||
|
}
|
||||||
|
// Verify all unique IDs are present
|
||||||
|
for _, id := range tt.ids {
|
||||||
|
if _, ok := got[id]; !ok {
|
||||||
|
t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldCreateAccount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
crsID string
|
||||||
|
selectedSet map[string]struct{}
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil set allows all (backward compatible)",
|
||||||
|
crsID: "any-id",
|
||||||
|
selectedSet: nil,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty set blocks all",
|
||||||
|
crsID: "any-id",
|
||||||
|
selectedSet: map[string]struct{}{},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ID in set is allowed",
|
||||||
|
crsID: "abc-123",
|
||||||
|
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ID not in set is blocked",
|
||||||
|
crsID: "xyz-789",
|
||||||
|
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := shouldCreateAccount(tt.crsID, tt.selectedSet)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v",
|
||||||
|
tt.crsID, tt.selectedSet, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -45,10 +45,11 @@ func NewCRSSyncService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SyncFromCRSInput struct {
|
type SyncFromCRSInput struct {
|
||||||
BaseURL string
|
BaseURL string
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
SyncProxies bool
|
SyncProxies bool
|
||||||
|
SelectedAccountIDs []string // if non-empty, only create new accounts with these CRS IDs
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncFromCRSItemResult struct {
|
type SyncFromCRSItemResult struct {
|
||||||
@@ -190,25 +191,27 @@ type crsGeminiAPIKeyAccount struct {
|
|||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
// fetchCRSExport validates the connection parameters, authenticates with CRS,
|
||||||
|
// and returns the exported accounts. Shared by SyncFromCRS and PreviewFromCRS.
|
||||||
|
func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, password string) (*crsExportResponse, error) {
|
||||||
if s.cfg == nil {
|
if s.cfg == nil {
|
||||||
return nil, errors.New("config is not available")
|
return nil, errors.New("config is not available")
|
||||||
}
|
}
|
||||||
baseURL := strings.TrimSpace(input.BaseURL)
|
normalizedURL := strings.TrimSpace(baseURL)
|
||||||
if s.cfg.Security.URLAllowlist.Enabled {
|
if s.cfg.Security.URLAllowlist.Enabled {
|
||||||
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
normalized, err := normalizeBaseURL(normalizedURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
baseURL = normalized
|
normalizedURL = normalized
|
||||||
} else {
|
} else {
|
||||||
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
normalized, err := urlvalidator.ValidateURLFormat(normalizedURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid base_url: %w", err)
|
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||||
}
|
}
|
||||||
baseURL = normalized
|
normalizedURL = normalized
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
|
if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" {
|
||||||
return nil, errors.New("username and password are required")
|
return nil, errors.New("username and password are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,12 +224,16 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
client = &http.Client{Timeout: 20 * time.Second}
|
client = &http.Client{Timeout: 20 * time.Second}
|
||||||
}
|
}
|
||||||
|
|
||||||
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
|
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
|
return crsExportAccounts(ctx, client, normalizedURL, adminToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||||
|
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -241,6 +248,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
selectedSet := buildSelectedSet(input.SelectedAccountIDs)
|
||||||
|
|
||||||
var proxies []Proxy
|
var proxies []Proxy
|
||||||
if input.SyncProxies {
|
if input.SyncProxies {
|
||||||
proxies, _ = s.proxyRepo.ListActive(ctx)
|
proxies, _ = s.proxyRepo.ListActive(ctx)
|
||||||
@@ -329,6 +338,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
|
|
||||||
if existing == nil {
|
if existing == nil {
|
||||||
|
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||||
|
item.Action = "skipped"
|
||||||
|
item.Error = "not selected"
|
||||||
|
result.Skipped++
|
||||||
|
result.Items = append(result.Items, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: defaultName(src.Name, src.ID),
|
Name: defaultName(src.Name, src.ID),
|
||||||
Platform: PlatformAnthropic,
|
Platform: PlatformAnthropic,
|
||||||
@@ -446,6 +462,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
|
|
||||||
if existing == nil {
|
if existing == nil {
|
||||||
|
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||||
|
item.Action = "skipped"
|
||||||
|
item.Error = "not selected"
|
||||||
|
result.Skipped++
|
||||||
|
result.Items = append(result.Items, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: defaultName(src.Name, src.ID),
|
Name: defaultName(src.Name, src.ID),
|
||||||
Platform: PlatformAnthropic,
|
Platform: PlatformAnthropic,
|
||||||
@@ -569,6 +592,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
|
|
||||||
if existing == nil {
|
if existing == nil {
|
||||||
|
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||||
|
item.Action = "skipped"
|
||||||
|
item.Error = "not selected"
|
||||||
|
result.Skipped++
|
||||||
|
result.Items = append(result.Items, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: defaultName(src.Name, src.ID),
|
Name: defaultName(src.Name, src.ID),
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
@@ -690,6 +720,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
|
|
||||||
if existing == nil {
|
if existing == nil {
|
||||||
|
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||||
|
item.Action = "skipped"
|
||||||
|
item.Error = "not selected"
|
||||||
|
result.Skipped++
|
||||||
|
result.Items = append(result.Items, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: defaultName(src.Name, src.ID),
|
Name: defaultName(src.Name, src.ID),
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
@@ -798,6 +835,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
|
|
||||||
if existing == nil {
|
if existing == nil {
|
||||||
|
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||||
|
item.Action = "skipped"
|
||||||
|
item.Error = "not selected"
|
||||||
|
result.Skipped++
|
||||||
|
result.Items = append(result.Items, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: defaultName(src.Name, src.ID),
|
Name: defaultName(src.Name, src.ID),
|
||||||
Platform: PlatformGemini,
|
Platform: PlatformGemini,
|
||||||
@@ -909,6 +953,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
|||||||
}
|
}
|
||||||
|
|
||||||
if existing == nil {
|
if existing == nil {
|
||||||
|
if !shouldCreateAccount(src.ID, selectedSet) {
|
||||||
|
item.Action = "skipped"
|
||||||
|
item.Error = "not selected"
|
||||||
|
result.Skipped++
|
||||||
|
result.Items = append(result.Items, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: defaultName(src.Name, src.ID),
|
Name: defaultName(src.Name, src.ID),
|
||||||
Platform: PlatformGemini,
|
Platform: PlatformGemini,
|
||||||
@@ -1253,3 +1304,102 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
|
|||||||
|
|
||||||
return newCredentials
|
return newCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildSelectedSet converts a slice of selected CRS account IDs to a set for O(1) lookup.
|
||||||
|
// Returns nil if ids is nil (field not sent → backward compatible: create all).
|
||||||
|
// Returns an empty map if ids is non-nil but empty (user selected none → create none).
|
||||||
|
func buildSelectedSet(ids []string) map[string]struct{} {
|
||||||
|
if ids == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
set := make(map[string]struct{}, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
set[id] = struct{}{}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldCreateAccount checks if a new CRS account should be created based on user selection.
|
||||||
|
// Returns true if selectedSet is nil (backward compatible: create all) or if crsID is in the set.
|
||||||
|
func shouldCreateAccount(crsID string, selectedSet map[string]struct{}) bool {
|
||||||
|
if selectedSet == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, ok := selectedSet[crsID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreviewFromCRSResult contains the preview of accounts from CRS before sync.
|
||||||
|
type PreviewFromCRSResult struct {
|
||||||
|
NewAccounts []CRSPreviewAccount `json:"new_accounts"`
|
||||||
|
ExistingAccounts []CRSPreviewAccount `json:"existing_accounts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CRSPreviewAccount represents a single account in the preview result.
|
||||||
|
type CRSPreviewAccount struct {
|
||||||
|
CRSAccountID string `json:"crs_account_id"`
|
||||||
|
Kind string `json:"kind"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreviewFromCRS connects to CRS, fetches all accounts, and classifies them
|
||||||
|
// as new or existing by batch-querying local crs_account_id mappings.
|
||||||
|
func (s *CRSSyncService) PreviewFromCRS(ctx context.Context, input SyncFromCRSInput) (*PreviewFromCRSResult, error) {
|
||||||
|
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Batch query all existing CRS account IDs
|
||||||
|
existingCRSIDs, err := s.accountRepo.ListCRSAccountIDs(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list existing CRS accounts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &PreviewFromCRSResult{
|
||||||
|
NewAccounts: make([]CRSPreviewAccount, 0),
|
||||||
|
ExistingAccounts: make([]CRSPreviewAccount, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
classify := func(crsID, kind, name, platform, accountType string) {
|
||||||
|
preview := CRSPreviewAccount{
|
||||||
|
CRSAccountID: crsID,
|
||||||
|
Kind: kind,
|
||||||
|
Name: defaultName(name, crsID),
|
||||||
|
Platform: platform,
|
||||||
|
Type: accountType,
|
||||||
|
}
|
||||||
|
if _, exists := existingCRSIDs[crsID]; exists {
|
||||||
|
result.ExistingAccounts = append(result.ExistingAccounts, preview)
|
||||||
|
} else {
|
||||||
|
result.NewAccounts = append(result.NewAccounts, preview)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, src := range exported.Data.ClaudeAccounts {
|
||||||
|
authType := strings.TrimSpace(src.AuthType)
|
||||||
|
if authType == "" {
|
||||||
|
authType = AccountTypeOAuth
|
||||||
|
}
|
||||||
|
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, authType)
|
||||||
|
}
|
||||||
|
for _, src := range exported.Data.ClaudeConsoleAccounts {
|
||||||
|
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, AccountTypeAPIKey)
|
||||||
|
}
|
||||||
|
for _, src := range exported.Data.OpenAIOAuthAccounts {
|
||||||
|
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeOAuth)
|
||||||
|
}
|
||||||
|
for _, src := range exported.Data.OpenAIResponsesAccounts {
|
||||||
|
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeAPIKey)
|
||||||
|
}
|
||||||
|
for _, src := range exported.Data.GeminiOAuthAccounts {
|
||||||
|
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeOAuth)
|
||||||
|
}
|
||||||
|
for _, src := range exported.Data.GeminiAPIKeyAccounts {
|
||||||
|
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -61,6 +61,11 @@ func applyErrorPassthroughRule(
|
|||||||
errMsg = *rule.CustomMessage
|
errMsg = *rule.CustomMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。
|
||||||
|
if rule.SkipMonitoring {
|
||||||
|
c.Set(OpsSkipPassthroughKey, true)
|
||||||
|
}
|
||||||
|
|
||||||
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
|
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
|
||||||
errType = "upstream_error"
|
errType = "upstream_error"
|
||||||
return status, errType, errMsg, true
|
return status, errType, errMsg, true
|
||||||
|
|||||||
@@ -194,6 +194,63 @@ func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
|
|||||||
assert.Equal(t, "Gemini上游失败", errField["message"])
|
assert.Equal(t, "Gemini上游失败", errField["message"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
|
||||||
|
rule.SkipMonitoring = true
|
||||||
|
|
||||||
|
ruleSvc := &ErrorPassthroughService{}
|
||||||
|
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
|
||||||
|
BindErrorPassthroughService(c, ruleSvc)
|
||||||
|
|
||||||
|
_, _, _, matched := applyErrorPassthroughRule(
|
||||||
|
c,
|
||||||
|
PlatformAnthropic,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
[]byte(`{"error":{"message":"prompt is too long"}}`),
|
||||||
|
http.StatusBadGateway,
|
||||||
|
"upstream_error",
|
||||||
|
"Upstream request failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.True(t, matched)
|
||||||
|
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||||
|
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||||
|
boolVal, ok := v.(bool)
|
||||||
|
assert.True(t, ok, "value should be bool")
|
||||||
|
assert.True(t, boolVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
|
||||||
|
rule.SkipMonitoring = false
|
||||||
|
|
||||||
|
ruleSvc := &ErrorPassthroughService{}
|
||||||
|
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
|
||||||
|
BindErrorPassthroughService(c, ruleSvc)
|
||||||
|
|
||||||
|
_, _, _, matched := applyErrorPassthroughRule(
|
||||||
|
c,
|
||||||
|
PlatformAnthropic,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
[]byte(`{"error":{"message":"prompt is too long"}}`),
|
||||||
|
http.StatusBadGateway,
|
||||||
|
"upstream_error",
|
||||||
|
"Upstream request failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.True(t, matched)
|
||||||
|
_, exists := c.Get(OpsSkipPassthroughKey)
|
||||||
|
assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false")
|
||||||
|
}
|
||||||
|
|
||||||
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
|
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
|
||||||
return &model.ErrorPassthroughRule{
|
return &model.ErrorPassthroughRule{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
|
|||||||
@@ -45,10 +45,20 @@ type ErrorPassthroughService struct {
|
|||||||
cache ErrorPassthroughCache
|
cache ErrorPassthroughCache
|
||||||
|
|
||||||
// 本地内存缓存,用于快速匹配
|
// 本地内存缓存,用于快速匹配
|
||||||
localCache []*model.ErrorPassthroughRule
|
localCache []*cachedPassthroughRule
|
||||||
localCacheMu sync.RWMutex
|
localCacheMu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower
|
||||||
|
type cachedPassthroughRule struct {
|
||||||
|
*model.ErrorPassthroughRule
|
||||||
|
lowerKeywords []string // 预计算的小写关键词
|
||||||
|
lowerPlatforms []string // 预计算的小写平台
|
||||||
|
errorCodeSet map[int]struct{} // 预计算的 error code set
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxBodyMatchLen = 8 << 10 // 8KB,错误信息不会在 8KB 之后才出现
|
||||||
|
|
||||||
// NewErrorPassthroughService 创建错误透传规则服务
|
// NewErrorPassthroughService 创建错误透传规则服务
|
||||||
func NewErrorPassthroughService(
|
func NewErrorPassthroughService(
|
||||||
repo ErrorPassthroughRepository,
|
repo ErrorPassthroughRepository,
|
||||||
@@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyStr := strings.ToLower(string(body))
|
lowerPlatform := strings.ToLower(platform)
|
||||||
|
var bodyLower string // 延迟初始化,只在需要关键词匹配时计算
|
||||||
|
var bodyLowerDone bool
|
||||||
|
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if !rule.Enabled {
|
if !rule.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !s.platformMatches(rule, platform) {
|
if !s.platformMatchesCached(rule, lowerPlatform) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if s.ruleMatches(rule, statusCode, bodyStr) {
|
if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) {
|
||||||
return rule
|
return rule.ErrorPassthroughRule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getCachedRules 获取缓存的规则列表(按优先级排序)
|
// getCachedRules 获取缓存的规则列表(按优先级排序)
|
||||||
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
|
func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
|
||||||
s.localCacheMu.RLock()
|
s.localCacheMu.RLock()
|
||||||
rules := s.localCache
|
rules := s.localCache
|
||||||
s.localCacheMu.RUnlock()
|
s.localCacheMu.RUnlock()
|
||||||
@@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setLocalCache 设置本地缓存
|
// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算
|
||||||
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
|
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
|
||||||
|
cached := make([]*cachedPassthroughRule, len(rules))
|
||||||
|
for i, r := range rules {
|
||||||
|
cr := &cachedPassthroughRule{ErrorPassthroughRule: r}
|
||||||
|
if len(r.Keywords) > 0 {
|
||||||
|
cr.lowerKeywords = make([]string, len(r.Keywords))
|
||||||
|
for j, kw := range r.Keywords {
|
||||||
|
cr.lowerKeywords[j] = strings.ToLower(kw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(r.Platforms) > 0 {
|
||||||
|
cr.lowerPlatforms = make([]string, len(r.Platforms))
|
||||||
|
for j, p := range r.Platforms {
|
||||||
|
cr.lowerPlatforms[j] = strings.ToLower(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(r.ErrorCodes) > 0 {
|
||||||
|
cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes))
|
||||||
|
for _, code := range r.ErrorCodes {
|
||||||
|
cr.errorCodeSet[code] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cached[i] = cr
|
||||||
|
}
|
||||||
|
|
||||||
// 按优先级排序
|
// 按优先级排序
|
||||||
sorted := make([]*model.ErrorPassthroughRule, len(rules))
|
sort.Slice(cached, func(i, j int) bool {
|
||||||
copy(sorted, rules)
|
return cached[i].Priority < cached[j].Priority
|
||||||
sort.Slice(sorted, func(i, j int) bool {
|
|
||||||
return sorted[i].Priority < sorted[j].Priority
|
|
||||||
})
|
})
|
||||||
|
|
||||||
s.localCacheMu.Lock()
|
s.localCacheMu.Lock()
|
||||||
s.localCache = sorted
|
s.localCache = cached
|
||||||
s.localCacheMu.Unlock()
|
s.localCacheMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -273,62 +307,79 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// platformMatches 检查平台是否匹配
|
// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB
|
||||||
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
|
func ensureBodyLower(body []byte, bodyLower *string, done *bool) string {
|
||||||
// 如果没有配置平台限制,则匹配所有平台
|
if *done {
|
||||||
if len(rule.Platforms) == 0 {
|
return *bodyLower
|
||||||
|
}
|
||||||
|
b := body
|
||||||
|
if len(b) > maxBodyMatchLen {
|
||||||
|
b = b[:maxBodyMatchLen]
|
||||||
|
}
|
||||||
|
*bodyLower = strings.ToLower(string(b))
|
||||||
|
*done = true
|
||||||
|
return *bodyLower
|
||||||
|
}
|
||||||
|
|
||||||
|
// platformMatchesCached 使用预计算的小写平台检查是否匹配
|
||||||
|
func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool {
|
||||||
|
if len(rule.lowerPlatforms) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
for _, p := range rule.lowerPlatforms {
|
||||||
platform = strings.ToLower(platform)
|
if p == lowerPlatform {
|
||||||
for _, p := range rule.Platforms {
|
|
||||||
if strings.ToLower(p) == platform {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ruleMatches 检查规则是否匹配
|
// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换
|
||||||
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
|
func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool {
|
||||||
hasErrorCodes := len(rule.ErrorCodes) > 0
|
hasErrorCodes := len(rule.errorCodeSet) > 0
|
||||||
hasKeywords := len(rule.Keywords) > 0
|
hasKeywords := len(rule.lowerKeywords) > 0
|
||||||
|
|
||||||
// 如果没有配置任何条件,不匹配
|
|
||||||
if !hasErrorCodes && !hasKeywords {
|
if !hasErrorCodes && !hasKeywords {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
|
codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode)
|
||||||
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
|
|
||||||
|
|
||||||
if rule.MatchMode == model.MatchModeAll {
|
if rule.MatchMode == model.MatchModeAll {
|
||||||
// "all" 模式:所有配置的条件都必须满足
|
// "all" 模式:所有配置的条件都必须满足,短路
|
||||||
return codeMatch && keywordMatch
|
if hasErrorCodes && !codeMatch {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasKeywords {
|
||||||
|
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||||
|
}
|
||||||
|
return codeMatch
|
||||||
}
|
}
|
||||||
|
|
||||||
// "any" 模式:任一条件满足即可
|
// "any" 模式:任一条件满足即可,短路
|
||||||
if hasErrorCodes && hasKeywords {
|
if hasErrorCodes && hasKeywords {
|
||||||
return codeMatch || keywordMatch
|
if codeMatch {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||||
}
|
}
|
||||||
return codeMatch && keywordMatch
|
// 只配置了一种条件
|
||||||
|
if hasKeywords {
|
||||||
|
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||||
|
}
|
||||||
|
return codeMatch
|
||||||
}
|
}
|
||||||
|
|
||||||
// containsInt 检查切片是否包含指定整数
|
// containsIntSet 使用 map 查找替代线性扫描
|
||||||
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
|
func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool {
|
||||||
for _, v := range slice {
|
_, ok := set[val]
|
||||||
if v == val {
|
return ok
|
||||||
return true
|
}
|
||||||
}
|
|
||||||
}
|
// containsAnyKeywordCached 使用预计算的小写关键词检查匹配
|
||||||
return false
|
func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool {
|
||||||
}
|
for _, kw := range lowerKeywords {
|
||||||
|
if strings.Contains(bodyLower, kw) {
|
||||||
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
|
|
||||||
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
|
|
||||||
for _, kw := range keywords {
|
|
||||||
if strings.Contains(bodyLower, strings.ToLower(kw)) {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule(测试用)
|
||||||
|
func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule {
|
||||||
|
cr := &cachedPassthroughRule{ErrorPassthroughRule: rule}
|
||||||
|
if len(rule.Keywords) > 0 {
|
||||||
|
cr.lowerKeywords = make([]string, len(rule.Keywords))
|
||||||
|
for j, kw := range rule.Keywords {
|
||||||
|
cr.lowerKeywords[j] = strings.ToLower(kw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(rule.Platforms) > 0 {
|
||||||
|
cr.lowerPlatforms = make([]string, len(rule.Platforms))
|
||||||
|
for j, p := range rule.Platforms {
|
||||||
|
cr.lowerPlatforms[j] = strings.ToLower(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(rule.ErrorCodes) > 0 {
|
||||||
|
cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes))
|
||||||
|
for _, code := range rule.ErrorCodes {
|
||||||
|
cr.errorCodeSet[code] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cr
|
||||||
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// 测试 ruleMatches 核心匹配逻辑
|
// 测试 ruleMatchesOptimized 核心匹配逻辑
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
func TestRuleMatches_NoConditions(t *testing.T) {
|
func TestRuleMatches_NoConditions(t *testing.T) {
|
||||||
// 没有配置任何条件时,不应该匹配
|
// 没有配置任何条件时,不应该匹配
|
||||||
svc := newTestService(nil)
|
svc := newTestService(nil)
|
||||||
rule := &model.ErrorPassthroughRule{
|
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
ErrorCodes: []int{},
|
ErrorCodes: []int{},
|
||||||
Keywords: []string{},
|
Keywords: []string{},
|
||||||
MatchMode: model.MatchModeAny,
|
MatchMode: model.MatchModeAny,
|
||||||
}
|
})
|
||||||
|
|
||||||
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
|
var bodyLower string
|
||||||
|
var bodyLowerDone bool
|
||||||
|
assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone),
|
||||||
"没有配置条件时不应该匹配")
|
"没有配置条件时不应该匹配")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||||
svc := newTestService(nil)
|
svc := newTestService(nil)
|
||||||
rule := &model.ErrorPassthroughRule{
|
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
ErrorCodes: []int{422, 400},
|
ErrorCodes: []int{422, 400},
|
||||||
Keywords: []string{},
|
Keywords: []string{},
|
||||||
MatchMode: model.MatchModeAny,
|
MatchMode: model.MatchModeAny,
|
||||||
}
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
var bodyLower string
|
||||||
|
var bodyLowerDone bool
|
||||||
|
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||||
assert.Equal(t, tt.expected, result)
|
assert.Equal(t, tt.expected, result)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
|||||||
|
|
||||||
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||||
svc := newTestService(nil)
|
svc := newTestService(nil)
|
||||||
rule := &model.ErrorPassthroughRule{
|
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
ErrorCodes: []int{},
|
ErrorCodes: []int{},
|
||||||
Keywords: []string{"context limit", "model not supported"},
|
Keywords: []string{"context limit", "model not supported"},
|
||||||
MatchMode: model.MatchModeAny,
|
MatchMode: model.MatchModeAny,
|
||||||
}
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
|||||||
{"关键词匹配 context limit", 500, "error: context limit reached", true},
|
{"关键词匹配 context limit", 500, "error: context limit reached", true},
|
||||||
{"关键词匹配 model not supported", 400, "the model not supported here", true},
|
{"关键词匹配 model not supported", 400, "the model not supported here", true},
|
||||||
{"关键词不匹配", 422, "some other error", false},
|
{"关键词不匹配", 422, "some other error", false},
|
||||||
// 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的
|
{"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true},
|
||||||
// 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches
|
|
||||||
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// 模拟 MatchRule 的行为:先转换为小写
|
var bodyLower string
|
||||||
bodyLower := strings.ToLower(tt.body)
|
var bodyLowerDone bool
|
||||||
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
|
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||||
assert.Equal(t, tt.expected, result)
|
assert.Equal(t, tt.expected, result)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
|||||||
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||||
// any 模式:错误码 OR 关键词
|
// any 模式:错误码 OR 关键词
|
||||||
svc := newTestService(nil)
|
svc := newTestService(nil)
|
||||||
rule := &model.ErrorPassthroughRule{
|
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
ErrorCodes: []int{422, 400},
|
ErrorCodes: []int{422, 400},
|
||||||
Keywords: []string{"context limit"},
|
Keywords: []string{"context limit"},
|
||||||
MatchMode: model.MatchModeAny,
|
MatchMode: model.MatchModeAny,
|
||||||
}
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
var bodyLower string
|
||||||
|
var bodyLowerDone bool
|
||||||
|
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||||
assert.Equal(t, tt.expected, result, tt.reason)
|
assert.Equal(t, tt.expected, result, tt.reason)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
|||||||
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||||
// all 模式:错误码 AND 关键词
|
// all 模式:错误码 AND 关键词
|
||||||
svc := newTestService(nil)
|
svc := newTestService(nil)
|
||||||
rule := &model.ErrorPassthroughRule{
|
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
ErrorCodes: []int{422, 400},
|
ErrorCodes: []int{422, 400},
|
||||||
Keywords: []string{"context limit"},
|
Keywords: []string{"context limit"},
|
||||||
MatchMode: model.MatchModeAll,
|
MatchMode: model.MatchModeAll,
|
||||||
}
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
var bodyLower string
|
||||||
|
var bodyLowerDone bool
|
||||||
|
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||||
assert.Equal(t, tt.expected, result, tt.reason)
|
assert.Equal(t, tt.expected, result, tt.reason)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// 测试 platformMatches 平台匹配逻辑
|
// 测试 platformMatchesCached 平台匹配逻辑
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
func TestPlatformMatches(t *testing.T) {
|
func TestPlatformMatches(t *testing.T) {
|
||||||
@@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
rule := &model.ErrorPassthroughRule{
|
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||||
Platforms: tt.rulePlatforms,
|
Platforms: tt.rulePlatforms,
|
||||||
}
|
})
|
||||||
result := svc.platformMatches(rule, tt.requestPlatform)
|
result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform))
|
||||||
assert.Equal(t, tt.expected, result)
|
assert.Equal(t, tt.expected, result)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
|||||||
customCodes: []any{float64(500)},
|
customCodes: []any{float64(500)},
|
||||||
expectHandleError: 0,
|
expectHandleError: 0,
|
||||||
expectUpstream: 1,
|
expectUpstream: 1,
|
||||||
expectStatusCode: 429,
|
expectStatusCode: 500,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "500_in_custom_codes_matched",
|
name: "500_in_custom_codes_matched",
|
||||||
@@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
|||||||
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||||
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type epTrackingRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
rateLimitedCalls int
|
||||||
|
rateLimitedID int64
|
||||||
|
setErrCalls int
|
||||||
|
setErrID int64
|
||||||
|
tempCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||||
|
r.rateLimitedCalls++
|
||||||
|
r.rateLimitedID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
|
||||||
|
r.setErrCalls++
|
||||||
|
r.setErrID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||||
|
r.tempCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
|
||||||
|
//
|
||||||
|
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
|
||||||
|
// 当上游返回 429/500/503/401 时:
|
||||||
|
// - 返回给客户端的状态码必须是 500(而不是透传原始状态码)
|
||||||
|
// - 不调用 SetRateLimited(不进入限流状态)
|
||||||
|
// - 不调用 SetError(不停止调度)
|
||||||
|
// - 不调用 handleError
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
|
||||||
|
errorCodes := []int{429, 500, 503, 401, 403}
|
||||||
|
|
||||||
|
for _, upstreamStatus := range errorCodes {
|
||||||
|
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{
|
||||||
|
statusCode: upstreamStatus,
|
||||||
|
body: `{"error":"some upstream error"}`,
|
||||||
|
}
|
||||||
|
repo := &epTrackingRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 500,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(599)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var handleErrorCount int
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
handleErrorCount++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
// 不应返回 error(Skipped 不触发账号切换)
|
||||||
|
require.NoError(t, err, "should not return error")
|
||||||
|
require.NotNil(t, result, "result should not be nil")
|
||||||
|
require.NotNil(t, result.resp, "response should not be nil")
|
||||||
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
|
|
||||||
|
// 状态码必须是 500(不透传原始状态码)
|
||||||
|
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
|
||||||
|
"skipped error should return 500, not %d", upstreamStatus)
|
||||||
|
|
||||||
|
// 不调用 handleError
|
||||||
|
require.Equal(t, 0, handleErrorCount,
|
||||||
|
"handleError should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 不标记限流
|
||||||
|
require.Equal(t, 0, repo.rateLimitedCalls,
|
||||||
|
"SetRateLimited should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 不停止调度
|
||||||
|
require.Equal(t, 0, repo.setErrCalls,
|
||||||
|
"SetError should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 只调用一次上游(不重试)
|
||||||
|
require.Equal(t, 1, upstream.calls,
|
||||||
|
"should call upstream exactly once (no retry)")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode int
|
statusCode int
|
||||||
body []byte
|
body []byte
|
||||||
expectedHandled bool
|
expectedHandled bool
|
||||||
|
expectedStatus int // expected outStatus
|
||||||
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||||||
handleErrorCalls int
|
handleErrorCalls int
|
||||||
}{
|
}{
|
||||||
@@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: false,
|
expectedHandled: false,
|
||||||
|
expectedStatus: 500, // passthrough
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500, // not in custom codes
|
statusCode: 500, // not in custom codes
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: http.StatusInternalServerError, // skipped → 500
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: 500, // matched → original status
|
||||||
handleErrorCalls: 1,
|
handleErrorCalls: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 503,
|
statusCode: 503,
|
||||||
body: []byte(`overloaded`),
|
body: []byte(`overloaded`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: 503, // temp_unscheduled → original status
|
||||||
expectedSwitchErr: true,
|
expectedSwitchErr: true,
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
@@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||||
|
|
||||||
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||||||
|
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
|
||||||
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||||||
|
|
||||||
if tt.expectedSwitchErr {
|
if tt.expectedSwitchErr {
|
||||||
|
|||||||
@@ -77,6 +77,9 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
|
|||||||
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error {
|
func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -84,7 +87,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
|
|||||||
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||||
|
|||||||
@@ -101,9 +101,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// thinking: {type: "enabled"}
|
// thinking: {type: "enabled" | "adaptive"}
|
||||||
if rawThinking, ok := req["thinking"].(map[string]any); ok {
|
if rawThinking, ok := req["thinking"].(map[string]any); ok {
|
||||||
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
|
if t, ok := rawThinking["type"].(string); ok && (t == "enabled" || t == "adaptive") {
|
||||||
parsed.ThinkingEnabled = true
|
parsed.ThinkingEnabled = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -161,9 +161,9 @@ func parseIntegralNumber(raw any) (int, bool) {
|
|||||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||||
// This prevents 400 errors from invalid thinking block signatures
|
// This prevents 400 errors from invalid thinking block signatures
|
||||||
//
|
//
|
||||||
// Strategy:
|
// 策略:
|
||||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
// - 当 thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块
|
||||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
// - 当 thinking.type 是 "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块(避免 400)
|
||||||
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
|
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
|
||||||
func FilterThinkingBlocks(body []byte) []byte {
|
func FilterThinkingBlocks(body []byte) []byte {
|
||||||
return filterThinkingBlocksInternal(body, false)
|
return filterThinkingBlocksInternal(body, false)
|
||||||
@@ -489,9 +489,9 @@ func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterThinkingBlocksInternal removes invalid thinking blocks from request
|
// filterThinkingBlocksInternal removes invalid thinking blocks from request
|
||||||
// Strategy:
|
// 策略:
|
||||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
// - 当 thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块
|
||||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
// - 当 thinking.type 是 "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块
|
||||||
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||||
// Fast path: if body doesn't contain "thinking", skip parsing
|
// Fast path: if body doesn't contain "thinking", skip parsing
|
||||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||||
@@ -511,7 +511,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
|||||||
// Check if thinking is enabled
|
// Check if thinking is enabled
|
||||||
thinkingEnabled := false
|
thinkingEnabled := false
|
||||||
if thinking, ok := req["thinking"].(map[string]any); ok {
|
if thinking, ok := req["thinking"].(map[string]any); ok {
|
||||||
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
|
if thinkType, ok := thinking["type"].(string); ok && (thinkType == "enabled" || thinkType == "adaptive") {
|
||||||
thinkingEnabled = true
|
thinkingEnabled = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,14 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
|
|||||||
require.True(t, parsed.ThinkingEnabled)
|
require.True(t, parsed.ThinkingEnabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseGatewayRequest_ThinkingAdaptiveEnabled(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"},"messages":[{"content":"hi"}]}`)
|
||||||
|
parsed, err := ParseGatewayRequest(body, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
|
||||||
|
require.True(t, parsed.ThinkingEnabled)
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
|
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
|
||||||
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
|
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
|
||||||
parsed, err := ParseGatewayRequest(body, "")
|
parsed, err := ParseGatewayRequest(body, "")
|
||||||
@@ -209,6 +217,16 @@ func TestFilterThinkingBlocks(t *testing.T) {
|
|||||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
|
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
|
||||||
shouldFilter: true,
|
shouldFilter: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "does not filter signed thinking blocks when thinking adaptive",
|
||||||
|
input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"ok","signature":"sig_real_123"},{"type":"text","text":"B"}]}]}`,
|
||||||
|
shouldFilter: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filters unsigned thinking blocks when thinking adaptive",
|
||||||
|
input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"internal","signature":""},{"type":"text","text":"B"}]}]}`,
|
||||||
|
shouldFilter: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "handles no thinking blocks",
|
name: "handles no thinking blocks",
|
||||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
|
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
|
||||||
|
|||||||
@@ -243,6 +243,12 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
|
||||||
|
// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除
|
||||||
|
var systemBlockFilterPrefixes = []string{
|
||||||
|
"x-anthropic-billing-header",
|
||||||
|
}
|
||||||
|
|
||||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||||
|
|
||||||
@@ -343,6 +349,8 @@ type ClaudeUsage struct {
|
|||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||||
|
CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
|
||||||
|
CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForwardResult 转发结果
|
// ForwardResult 转发结果
|
||||||
@@ -362,15 +370,31 @@ type ForwardResult struct {
|
|||||||
|
|
||||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||||
type UpstreamFailoverError struct {
|
type UpstreamFailoverError struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||||
|
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *UpstreamFailoverError) Error() string {
|
func (e *UpstreamFailoverError) Error() string {
|
||||||
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
|
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。
|
||||||
|
// 由 handler 层在同账号重试全部用尽、切换账号时调用。
|
||||||
|
func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) {
|
||||||
|
if failoverErr == nil || !failoverErr.RetryableOnSameAccount {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 根据状态码选择封禁策略
|
||||||
|
switch failoverErr.StatusCode {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]")
|
||||||
|
case http.StatusBadGateway:
|
||||||
|
tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GatewayService handles API gateway operations
|
// GatewayService handles API gateway operations
|
||||||
type GatewayService struct {
|
type GatewayService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
@@ -1683,6 +1707,17 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
|||||||
return accounts, useMixed, nil
|
return accounts, useMixed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
|
||||||
|
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
|
||||||
|
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
|
||||||
|
func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool {
|
||||||
|
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return len(accounts) == 1
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
|
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return false
|
return false
|
||||||
@@ -2673,6 +2708,60 @@ func hasClaudeCodePrefix(text string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// matchesFilterPrefix 检查文本是否匹配任一过滤前缀
|
||||||
|
func matchesFilterPrefix(text string) bool {
|
||||||
|
for _, prefix := range systemBlockFilterPrefixes {
|
||||||
|
if strings.HasPrefix(text, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素
|
||||||
|
// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system)
|
||||||
|
func filterSystemBlocksByPrefix(body []byte) []byte {
|
||||||
|
sys := gjson.GetBytes(body, "system")
|
||||||
|
if !sys.Exists() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case sys.Type == gjson.String:
|
||||||
|
if matchesFilterPrefix(sys.Str) {
|
||||||
|
result, err := sjson.DeleteBytes(body, "system")
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
case sys.IsArray():
|
||||||
|
var parsed []any
|
||||||
|
if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
filtered := make([]any, 0, len(parsed))
|
||||||
|
changed := false
|
||||||
|
for _, item := range parsed {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) {
|
||||||
|
changed = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
}
|
||||||
|
if changed {
|
||||||
|
result, err := sjson.SetBytes(body, "system", filtered)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
||||||
// 处理 null、字符串、数组三种格式
|
// 处理 null、字符串、数组三种格式
|
||||||
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||||
@@ -2952,6 +3041,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据)
|
||||||
|
// 放在 inject/normalize 之后,确保不会被覆盖
|
||||||
|
if account.IsOAuth() {
|
||||||
|
body = filterSystemBlocksByPrefix(body)
|
||||||
|
}
|
||||||
|
|
||||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||||
body = enforceCacheControlLimit(body)
|
body = enforceCacheControlLimit(body)
|
||||||
|
|
||||||
@@ -3538,7 +3633,8 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
|||||||
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
|
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
|
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
||||||
|
if strings.EqualFold(thinkingType, "enabled") || strings.EqualFold(thinkingType, "adaptive") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -4307,6 +4403,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
|||||||
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
usage.InputTokens = msgStart.Message.Usage.InputTokens
|
||||||
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
|
||||||
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
|
||||||
|
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
|
||||||
|
cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
|
||||||
|
if cc5m.Exists() || cc1h.Exists() {
|
||||||
|
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||||
|
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
|
||||||
@@ -4335,6 +4439,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
|||||||
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
||||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||||
|
cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||||
|
if cc5m.Exists() || cc1h.Exists() {
|
||||||
|
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||||
|
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4355,6 +4467,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
return nil, fmt.Errorf("parse response: %w", err)
|
return nil, fmt.Errorf("parse response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 解析嵌套的 cache_creation 对象中的 5m/1h 明细
|
||||||
|
cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
|
||||||
|
cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
|
||||||
|
if cc5m.Exists() || cc1h.Exists() {
|
||||||
|
response.Usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||||
|
response.Usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||||
|
}
|
||||||
|
|
||||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||||
if response.Usage.CacheReadInputTokens == 0 {
|
if response.Usage.CacheReadInputTokens == 0 {
|
||||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||||
@@ -4472,10 +4592,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
} else {
|
} else {
|
||||||
// Token 计费
|
// Token 计费
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||||
@@ -4509,6 +4631,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
InputCost: cost.InputCost,
|
InputCost: cost.InputCost,
|
||||||
OutputCost: cost.OutputCost,
|
OutputCost: cost.OutputCost,
|
||||||
CacheCreationCost: cost.CacheCreationCost,
|
CacheCreationCost: cost.CacheCreationCost,
|
||||||
@@ -4653,10 +4777,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
} else {
|
} else {
|
||||||
// Token 计费(使用长上下文计费方法)
|
// Token 计费(使用长上下文计费方法)
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||||
@@ -4690,6 +4816,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
||||||
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
InputCost: cost.InputCost,
|
InputCost: cost.InputCost,
|
||||||
OutputCost: cost.OutputCost,
|
OutputCost: cost.OutputCost,
|
||||||
CacheCreationCost: cost.CacheCreationCost,
|
CacheCreationCost: cost.CacheCreationCost,
|
||||||
|
|||||||
@@ -770,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 错误策略优先:匹配则跳过重试直接处理。
|
||||||
|
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
|
||||||
|
resp = rebuilt
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
resp = rebuilt
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
@@ -839,7 +847,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
if upstreamReqID == "" {
|
if upstreamReqID == "" {
|
||||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
}
|
}
|
||||||
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
|
return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
|
||||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
@@ -872,6 +880,37 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
|
|
||||||
// ErrorPolicyNone → 原有逻辑
|
// ErrorPolicyNone → 原有逻辑
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||||
|
if isGoogleProjectConfigError(msg400) {
|
||||||
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
|
if upstreamReqID == "" {
|
||||||
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
|
}
|
||||||
|
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: upstreamReqID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
|
||||||
|
}
|
||||||
|
}
|
||||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
if upstreamReqID == "" {
|
if upstreamReqID == "" {
|
||||||
@@ -1176,6 +1215,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
|
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 错误策略优先:匹配则跳过重试直接处理。
|
||||||
|
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
|
||||||
|
resp = rebuilt
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
resp = rebuilt
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
@@ -1283,7 +1330,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
if contentType == "" {
|
if contentType == "" {
|
||||||
contentType = "application/json"
|
contentType = "application/json"
|
||||||
}
|
}
|
||||||
c.Data(resp.StatusCode, contentType, respBody)
|
c.Data(http.StatusInternalServerError, contentType, respBody)
|
||||||
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
||||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
@@ -1314,6 +1361,34 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
|
|
||||||
// ErrorPolicyNone → 原有逻辑
|
// ErrorPolicyNone → 原有逻辑
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||||
|
if isGoogleProjectConfigError(msg400) {
|
||||||
|
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||||
|
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody)))
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||||
|
}
|
||||||
|
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: requestID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true}
|
||||||
|
}
|
||||||
|
}
|
||||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||||
@@ -1425,6 +1500,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkErrorPolicyInLoop 在重试循环内预检查错误策略。
|
||||||
|
// 返回 true 表示策略已匹配(调用者应 break),resp 已重建可直接使用。
|
||||||
|
// 返回 false 表示 ErrorPolicyNone,resp 已重建,调用者继续走重试逻辑。
|
||||||
|
func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop(
|
||||||
|
ctx context.Context, account *Account, resp *http.Response,
|
||||||
|
) (matched bool, rebuilt *http.Response) {
|
||||||
|
if resp.StatusCode < 400 || s.rateLimitService == nil {
|
||||||
|
return false, resp
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
rebuilt = &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: resp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
|
}
|
||||||
|
policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body)
|
||||||
|
return policy != ErrorPolicyNone, rebuilt
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 429, 500, 502, 503, 504, 529:
|
case 429, 500, 502, 503, 504, 529:
|
||||||
@@ -2568,11 +2663,12 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
|
|||||||
prompt, _ := asInt(usageMeta["promptTokenCount"])
|
prompt, _ := asInt(usageMeta["promptTokenCount"])
|
||||||
cand, _ := asInt(usageMeta["candidatesTokenCount"])
|
cand, _ := asInt(usageMeta["candidatesTokenCount"])
|
||||||
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
|
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
|
||||||
|
thoughts, _ := asInt(usageMeta["thoughtsTokenCount"])
|
||||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||||
return &ClaudeUsage{
|
return &ClaudeUsage{
|
||||||
InputTokens: prompt - cached,
|
InputTokens: prompt - cached,
|
||||||
OutputTokens: cand,
|
OutputTokens: cand + thoughts,
|
||||||
CacheReadInputTokens: cached,
|
CacheReadInputTokens: cached,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2597,6 +2693,10 @@ func asInt(v any) (int, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||||
|
// 遵守自定义错误码策略:未命中则跳过所有限流处理
|
||||||
|
if !account.ShouldHandleErrorCode(statusCode) {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||||
@@ -203,3 +205,70 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing
|
|||||||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractGeminiUsage_ThoughtsTokenCount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
resp map[string]any
|
||||||
|
wantInput int
|
||||||
|
wantOutput int
|
||||||
|
wantCacheRead int
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with thoughtsTokenCount",
|
||||||
|
resp: map[string]any{
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": float64(100),
|
||||||
|
"candidatesTokenCount": float64(20),
|
||||||
|
"thoughtsTokenCount": float64(50),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantInput: 100,
|
||||||
|
wantOutput: 70,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with thoughtsTokenCount and cache",
|
||||||
|
resp: map[string]any{
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": float64(100),
|
||||||
|
"candidatesTokenCount": float64(20),
|
||||||
|
"cachedContentTokenCount": float64(30),
|
||||||
|
"thoughtsTokenCount": float64(50),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantInput: 70,
|
||||||
|
wantOutput: 70,
|
||||||
|
wantCacheRead: 30,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without thoughtsTokenCount (old model)",
|
||||||
|
resp: map[string]any{
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": float64(100),
|
||||||
|
"candidatesTokenCount": float64(20),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantInput: 100,
|
||||||
|
wantOutput: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no usageMetadata",
|
||||||
|
resp: map[string]any{},
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
usage := extractGeminiUsage(tt.resp)
|
||||||
|
if tt.wantNil {
|
||||||
|
require.Nil(t, usage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, tt.wantInput, usage.InputTokens)
|
||||||
|
require.Equal(t, tt.wantOutput, usage.OutputTokens)
|
||||||
|
require.Equal(t, tt.wantCacheRead, usage.CacheReadInputTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -66,12 +66,15 @@ func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account)
|
|||||||
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
|
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
|
||||||
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
|
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
|
||||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
|
|||||||
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
||||||
Page: page,
|
Page: page,
|
||||||
PageSize: opsAccountsPageSize,
|
PageSize: opsAccountsPageSize,
|
||||||
}, platformFilter, "", "", "")
|
}, platformFilter, "", "", "", 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,10 @@ const (
|
|||||||
// retry the specific upstream attempt (not just the client request).
|
// retry the specific upstream attempt (not just the client request).
|
||||||
// This value is sanitized+trimmed before being persisted.
|
// This value is sanitized+trimmed before being persisted.
|
||||||
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
|
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
|
||||||
|
|
||||||
|
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
|
||||||
|
// ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。
|
||||||
|
OpsSkipPassthroughKey = "ops_skip_passthrough"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
||||||
@@ -103,6 +107,37 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
|
|||||||
evCopy := ev
|
evCopy := ev
|
||||||
existing = append(existing, &evCopy)
|
existing = append(existing, &evCopy)
|
||||||
c.Set(OpsUpstreamErrorsKey, existing)
|
c.Set(OpsUpstreamErrorsKey, existing)
|
||||||
|
|
||||||
|
checkSkipMonitoringForUpstreamEvent(c, &evCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkSkipMonitoringForUpstreamEvent checks whether the upstream error event
|
||||||
|
// matches a passthrough rule with skip_monitoring=true and, if so, sets the
|
||||||
|
// OpsSkipPassthroughKey on the context. This ensures intermediate retry /
|
||||||
|
// failover errors (which never go through the final applyErrorPassthroughRule
|
||||||
|
// path) can still suppress ops_error_logs recording.
|
||||||
|
func checkSkipMonitoringForUpstreamEvent(c *gin.Context, ev *OpsUpstreamErrorEvent) {
|
||||||
|
if ev.UpstreamStatusCode == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := getBoundErrorPassthroughService(c)
|
||||||
|
if svc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the best available body representation for keyword matching.
|
||||||
|
// Even when body is empty, MatchRule can still match rules that only
|
||||||
|
// specify ErrorCodes (no Keywords), so we always call it.
|
||||||
|
body := ev.Detail
|
||||||
|
if body == "" {
|
||||||
|
body = ev.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := svc.MatchRule(ev.Platform, ev.UpstreamStatusCode, []byte(body))
|
||||||
|
if rule != nil && rule.SkipMonitoring {
|
||||||
|
c.Set(OpsSkipPassthroughKey, true)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
|
func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
|
||||||
|
|||||||
@@ -27,14 +27,15 @@ var (
|
|||||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||||
type LiteLLMModelPricing struct {
|
type LiteLLMModelPricing struct {
|
||||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||||
LiteLLMProvider string `json:"litellm_provider"`
|
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||||
Mode string `json:"mode"`
|
LiteLLMProvider string `json:"litellm_provider"`
|
||||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
Mode string `json:"mode"`
|
||||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||||
|
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||||
}
|
}
|
||||||
|
|
||||||
// PricingRemoteClient 远程价格数据获取接口
|
// PricingRemoteClient 远程价格数据获取接口
|
||||||
@@ -45,14 +46,15 @@ type PricingRemoteClient interface {
|
|||||||
|
|
||||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||||
type LiteLLMRawEntry struct {
|
type LiteLLMRawEntry struct {
|
||||||
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
InputCostPerToken *float64 `json:"input_cost_per_token"`
|
||||||
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
OutputCostPerToken *float64 `json:"output_cost_per_token"`
|
||||||
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
|
||||||
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
|
||||||
LiteLLMProvider string `json:"litellm_provider"`
|
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
|
||||||
Mode string `json:"mode"`
|
LiteLLMProvider string `json:"litellm_provider"`
|
||||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
Mode string `json:"mode"`
|
||||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||||
|
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PricingService 动态价格服务
|
// PricingService 动态价格服务
|
||||||
@@ -318,6 +320,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
|||||||
if entry.CacheCreationInputTokenCost != nil {
|
if entry.CacheCreationInputTokenCost != nil {
|
||||||
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
|
||||||
}
|
}
|
||||||
|
if entry.CacheCreationInputTokenCostAbove1hr != nil {
|
||||||
|
pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr
|
||||||
|
}
|
||||||
if entry.CacheReadInputTokenCost != nil {
|
if entry.CacheReadInputTokenCost != nil {
|
||||||
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 尝试从响应头解析重置时间(Anthropic)
|
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
|
||||||
|
if result := calculateAnthropic429ResetTime(headers); result != nil {
|
||||||
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
|
||||||
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推
|
||||||
|
windowEnd := result.resetAt
|
||||||
|
if result.fiveHourReset != nil {
|
||||||
|
windowEnd = *result.fiveHourReset
|
||||||
|
}
|
||||||
|
windowStart := windowEnd.Add(-5 * time.Hour)
|
||||||
|
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||||
|
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容)
|
||||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||||
|
|
||||||
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
// 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||||
if resetTimestamp == "" {
|
if resetTimestamp == "" {
|
||||||
switch account.Platform {
|
switch account.Platform {
|
||||||
case PlatformOpenAI:
|
case PlatformOpenAI:
|
||||||
@@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
|
||||||
|
type anthropic429Result struct {
|
||||||
|
resetAt time.Time // The correct reset time to use for SetRateLimited
|
||||||
|
fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers
|
||||||
|
// to determine which window (5h or 7d) actually triggered the 429.
|
||||||
|
//
|
||||||
|
// Headers used:
|
||||||
|
// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold
|
||||||
|
// - anthropic-ratelimit-unified-5h-reset
|
||||||
|
// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold
|
||||||
|
// - anthropic-ratelimit-unified-7d-reset
|
||||||
|
//
|
||||||
|
// Returns nil when the per-window headers are absent (caller should fall back to
|
||||||
|
// the aggregated anthropic-ratelimit-unified-reset header).
|
||||||
|
func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result {
|
||||||
|
reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset")
|
||||||
|
reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset")
|
||||||
|
|
||||||
|
if reset5hStr == "" && reset7dStr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var reset5h, reset7d *time.Time
|
||||||
|
if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil {
|
||||||
|
t := time.Unix(ts, 0)
|
||||||
|
reset5h = &t
|
||||||
|
}
|
||||||
|
if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil {
|
||||||
|
t := time.Unix(ts, 0)
|
||||||
|
reset7d = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
is5hExceeded := isAnthropicWindowExceeded(headers, "5h")
|
||||||
|
is7dExceeded := isAnthropicWindowExceeded(headers, "7d")
|
||||||
|
|
||||||
|
slog.Info("anthropic_429_window_analysis",
|
||||||
|
"is_5h_exceeded", is5hExceeded,
|
||||||
|
"is_7d_exceeded", is7dExceeded,
|
||||||
|
"reset_5h", reset5hStr,
|
||||||
|
"reset_7d", reset7dStr,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Select the correct reset time based on which window(s) are exceeded.
|
||||||
|
var chosen *time.Time
|
||||||
|
switch {
|
||||||
|
case is5hExceeded && is7dExceeded:
|
||||||
|
// Both exceeded → prefer 7d (longer cooldown), fall back to 5h
|
||||||
|
chosen = reset7d
|
||||||
|
if chosen == nil {
|
||||||
|
chosen = reset5h
|
||||||
|
}
|
||||||
|
case is5hExceeded:
|
||||||
|
chosen = reset5h
|
||||||
|
case is7dExceeded:
|
||||||
|
chosen = reset7d
|
||||||
|
default:
|
||||||
|
// Neither flag clearly exceeded — pick the sooner reset as best guess
|
||||||
|
chosen = pickSooner(reset5h, reset7d)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chosen == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window
|
||||||
|
// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers.
|
||||||
|
func isAnthropicWindowExceeded(headers http.Header, window string) bool {
|
||||||
|
prefix := "anthropic-ratelimit-unified-" + window + "-"
|
||||||
|
|
||||||
|
// Check surpassed-threshold first (most explicit signal)
|
||||||
|
if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to utilization >= 1.0
|
||||||
|
if utilStr := headers.Get(prefix + "utilization"); utilStr != "" {
|
||||||
|
if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 {
|
||||||
|
// Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickSooner returns whichever of the two time pointers is earlier.
|
||||||
|
// If only one is non-nil, it is returned. If both are nil, returns nil.
|
||||||
|
func pickSooner(a, b *time.Time) *time.Time {
|
||||||
|
switch {
|
||||||
|
case a != nil && b != nil:
|
||||||
|
if a.Before(*b) {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
case a != nil:
|
||||||
|
return a
|
||||||
|
default:
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||||
// OpenAI 的 usage_limit_reached 错误格式:
|
// OpenAI 的 usage_limit_reached 错误格式:
|
||||||
//
|
//
|
||||||
@@ -623,6 +750,10 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
|||||||
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
|
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 同时清除模型级别限流
|
||||||
|
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
|
||||||
|
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
202
backend/internal/service/ratelimit_service_anthropic_test.go
Normal file
202
backend/internal/service/ratelimit_service_anthropic_test.go
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
|
||||||
|
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
|
||||||
|
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1771549200)
|
||||||
|
|
||||||
|
// fiveHourReset should still be populated for session window calculation
|
||||||
|
if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
|
||||||
|
t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1771549200)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) {
|
||||||
|
result := calculateAnthropic429ResetTime(http.Header{})
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1770998400)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) {
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03")
|
||||||
|
headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
|
||||||
|
|
||||||
|
result := calculateAnthropic429ResetTime(headers)
|
||||||
|
assertAnthropicResult(t, result, 1771549200)
|
||||||
|
|
||||||
|
if result.fiveHourReset != nil {
|
||||||
|
t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAnthropicWindowExceeded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
headers http.Header
|
||||||
|
window string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "utilization above 1.0",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"),
|
||||||
|
window: "5h",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "utilization exactly 1.0",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"),
|
||||||
|
window: "5h",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "utilization below 1.0",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"),
|
||||||
|
window: "5h",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "surpassed-threshold true",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"),
|
||||||
|
window: "7d",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "surpassed-threshold True (case insensitive)",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"),
|
||||||
|
window: "7d",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "surpassed-threshold false",
|
||||||
|
headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"),
|
||||||
|
window: "7d",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no headers",
|
||||||
|
headers: http.Header{},
|
||||||
|
window: "5h",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := isAnthropicWindowExceeded(tc.headers, tc.window)
|
||||||
|
if got != tc.expected {
|
||||||
|
t.Errorf("expected %v, got %v", tc.expected, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertAnthropicResult is a test helper that verifies the result is non-nil and
|
||||||
|
// has the expected resetAt unix timestamp.
|
||||||
|
func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) {
|
||||||
|
t.Helper()
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected non-nil result")
|
||||||
|
return // unreachable, but satisfies staticcheck SA5011
|
||||||
|
}
|
||||||
|
want := time.Unix(wantUnix, 0)
|
||||||
|
if !result.resetAt.Equal(want) {
|
||||||
|
t.Errorf("expected resetAt=%v, got %v", want, result.resetAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeHeader(key, value string) http.Header {
|
||||||
|
h := http.Header{}
|
||||||
|
h.Set(key, value)
|
||||||
|
return h
|
||||||
|
}
|
||||||
@@ -26,8 +26,8 @@ type UsageLog struct {
|
|||||||
CacheCreationTokens int
|
CacheCreationTokens int
|
||||||
CacheReadTokens int
|
CacheReadTokens int
|
||||||
|
|
||||||
CacheCreation5mTokens int
|
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
|
||||||
CacheCreation1hTokens int
|
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
|
||||||
|
|
||||||
InputCost float64
|
InputCost float64
|
||||||
OutputCost float64
|
OutputCost float64
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
-- Add skip_monitoring field to error_passthrough_rules table
|
||||||
|
-- When true, errors matching this rule will not be recorded in ops_error_logs
|
||||||
|
ALTER TABLE error_passthrough_rules
|
||||||
|
ADD COLUMN IF NOT EXISTS skip_monitoring BOOLEAN NOT NULL DEFAULT false;
|
||||||
14
backend/migrations/054_drop_legacy_cache_columns.sql
Normal file
14
backend/migrations/054_drop_legacy_cache_columns.sql
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
-- Drop legacy cache token columns that lack the underscore separator.
|
||||||
|
-- These were created by GORM's automatic snake_case conversion:
|
||||||
|
-- CacheCreation5mTokens → cache_creation5m_tokens (incorrect)
|
||||||
|
-- CacheCreation1hTokens → cache_creation1h_tokens (incorrect)
|
||||||
|
--
|
||||||
|
-- The canonical columns are:
|
||||||
|
-- cache_creation_5m_tokens (defined in 001_init.sql)
|
||||||
|
-- cache_creation_1h_tokens (defined in 001_init.sql)
|
||||||
|
--
|
||||||
|
-- Migration 009 already copied data from legacy → canonical columns.
|
||||||
|
-- This migration drops the legacy columns to avoid confusion.
|
||||||
|
|
||||||
|
ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation5m_tokens;
|
||||||
|
ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation1h_tokens;
|
||||||
@@ -158,6 +158,7 @@ services:
|
|||||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||||
|
- PGDATA=/var/lib/postgresql/data
|
||||||
- TZ=${TZ:-Asia/Shanghai}
|
- TZ=${TZ:-Asia/Shanghai}
|
||||||
networks:
|
networks:
|
||||||
- sub2api-network
|
- sub2api-network
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@lobehub/icons": "^4.0.2",
|
"@lobehub/icons": "^4.0.2",
|
||||||
"@vueuse/core": "^10.7.0",
|
"@vueuse/core": "^10.7.0",
|
||||||
"axios": "^1.6.2",
|
"axios": "^1.13.5",
|
||||||
"chart.js": "^4.4.1",
|
"chart.js": "^4.4.1",
|
||||||
"dompurify": "^3.3.1",
|
"dompurify": "^3.3.1",
|
||||||
"driver.js": "^1.4.0",
|
"driver.js": "^1.4.0",
|
||||||
|
|||||||
21
frontend/pnpm-lock.yaml
generated
21
frontend/pnpm-lock.yaml
generated
@@ -15,8 +15,8 @@ importers:
|
|||||||
specifier: ^10.7.0
|
specifier: ^10.7.0
|
||||||
version: 10.11.1(vue@3.5.26(typescript@5.6.3))
|
version: 10.11.1(vue@3.5.26(typescript@5.6.3))
|
||||||
axios:
|
axios:
|
||||||
specifier: ^1.6.2
|
specifier: ^1.13.5
|
||||||
version: 1.13.2
|
version: 1.13.5
|
||||||
chart.js:
|
chart.js:
|
||||||
specifier: ^4.4.1
|
specifier: ^4.4.1
|
||||||
version: 4.5.1
|
version: 4.5.1
|
||||||
@@ -1257,56 +1257,67 @@ packages:
|
|||||||
resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==}
|
resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==}
|
||||||
cpu: [arm]
|
cpu: [arm]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [glibc]
|
||||||
|
|
||||||
'@rollup/rollup-linux-arm-musleabihf@4.54.0':
|
'@rollup/rollup-linux-arm-musleabihf@4.54.0':
|
||||||
resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==}
|
resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==}
|
||||||
cpu: [arm]
|
cpu: [arm]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [musl]
|
||||||
|
|
||||||
'@rollup/rollup-linux-arm64-gnu@4.54.0':
|
'@rollup/rollup-linux-arm64-gnu@4.54.0':
|
||||||
resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==}
|
resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==}
|
||||||
cpu: [arm64]
|
cpu: [arm64]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [glibc]
|
||||||
|
|
||||||
'@rollup/rollup-linux-arm64-musl@4.54.0':
|
'@rollup/rollup-linux-arm64-musl@4.54.0':
|
||||||
resolution: {integrity: sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==}
|
resolution: {integrity: sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==}
|
||||||
cpu: [arm64]
|
cpu: [arm64]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [musl]
|
||||||
|
|
||||||
'@rollup/rollup-linux-loong64-gnu@4.54.0':
|
'@rollup/rollup-linux-loong64-gnu@4.54.0':
|
||||||
resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==}
|
resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==}
|
||||||
cpu: [loong64]
|
cpu: [loong64]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [glibc]
|
||||||
|
|
||||||
'@rollup/rollup-linux-ppc64-gnu@4.54.0':
|
'@rollup/rollup-linux-ppc64-gnu@4.54.0':
|
||||||
resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==}
|
resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==}
|
||||||
cpu: [ppc64]
|
cpu: [ppc64]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [glibc]
|
||||||
|
|
||||||
'@rollup/rollup-linux-riscv64-gnu@4.54.0':
|
'@rollup/rollup-linux-riscv64-gnu@4.54.0':
|
||||||
resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==}
|
resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==}
|
||||||
cpu: [riscv64]
|
cpu: [riscv64]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [glibc]
|
||||||
|
|
||||||
'@rollup/rollup-linux-riscv64-musl@4.54.0':
|
'@rollup/rollup-linux-riscv64-musl@4.54.0':
|
||||||
resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==}
|
resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==}
|
||||||
cpu: [riscv64]
|
cpu: [riscv64]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [musl]
|
||||||
|
|
||||||
'@rollup/rollup-linux-s390x-gnu@4.54.0':
|
'@rollup/rollup-linux-s390x-gnu@4.54.0':
|
||||||
resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==}
|
resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==}
|
||||||
cpu: [s390x]
|
cpu: [s390x]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [glibc]
|
||||||
|
|
||||||
'@rollup/rollup-linux-x64-gnu@4.54.0':
|
'@rollup/rollup-linux-x64-gnu@4.54.0':
|
||||||
resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==}
|
resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==}
|
||||||
cpu: [x64]
|
cpu: [x64]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [glibc]
|
||||||
|
|
||||||
'@rollup/rollup-linux-x64-musl@4.54.0':
|
'@rollup/rollup-linux-x64-musl@4.54.0':
|
||||||
resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==}
|
resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==}
|
||||||
cpu: [x64]
|
cpu: [x64]
|
||||||
os: [linux]
|
os: [linux]
|
||||||
|
libc: [musl]
|
||||||
|
|
||||||
'@rollup/rollup-openharmony-arm64@4.54.0':
|
'@rollup/rollup-openharmony-arm64@4.54.0':
|
||||||
resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==}
|
resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==}
|
||||||
@@ -1805,8 +1816,8 @@ packages:
|
|||||||
peerDependencies:
|
peerDependencies:
|
||||||
postcss: ^8.1.0
|
postcss: ^8.1.0
|
||||||
|
|
||||||
axios@1.13.2:
|
axios@1.13.5:
|
||||||
resolution: {integrity: sha512-VPk9ebNqPcy5lRGuSlKx752IlDatOjT9paPlm8A7yOuW2Fbvp4X3JznJtT4f0GzGLLiWE9W8onz51SqLYwzGaA==}
|
resolution: {integrity: sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==}
|
||||||
|
|
||||||
babel-plugin-macros@3.1.0:
|
babel-plugin-macros@3.1.0:
|
||||||
resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==}
|
resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==}
|
||||||
@@ -6387,7 +6398,7 @@ snapshots:
|
|||||||
postcss: 8.5.6
|
postcss: 8.5.6
|
||||||
postcss-value-parser: 4.2.0
|
postcss-value-parser: 4.2.0
|
||||||
|
|
||||||
axios@1.13.2:
|
axios@1.13.5:
|
||||||
dependencies:
|
dependencies:
|
||||||
follow-redirects: 1.15.11
|
follow-redirects: 1.15.11
|
||||||
form-data: 4.0.5
|
form-data: 4.0.5
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ export async function list(
|
|||||||
platform?: string
|
platform?: string
|
||||||
type?: string
|
type?: string
|
||||||
status?: string
|
status?: string
|
||||||
|
group?: string
|
||||||
search?: string
|
search?: string
|
||||||
},
|
},
|
||||||
options?: {
|
options?: {
|
||||||
@@ -327,11 +328,34 @@ export async function getAvailableModels(id: number): Promise<ClaudeModel[]> {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface CRSPreviewAccount {
|
||||||
|
crs_account_id: string
|
||||||
|
kind: string
|
||||||
|
name: string
|
||||||
|
platform: string
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PreviewFromCRSResult {
|
||||||
|
new_accounts: CRSPreviewAccount[]
|
||||||
|
existing_accounts: CRSPreviewAccount[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function previewFromCrs(params: {
|
||||||
|
base_url: string
|
||||||
|
username: string
|
||||||
|
password: string
|
||||||
|
}): Promise<PreviewFromCRSResult> {
|
||||||
|
const { data } = await apiClient.post<PreviewFromCRSResult>('/admin/accounts/sync/crs/preview', params)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
export async function syncFromCrs(params: {
|
export async function syncFromCrs(params: {
|
||||||
base_url: string
|
base_url: string
|
||||||
username: string
|
username: string
|
||||||
password: string
|
password: string
|
||||||
sync_proxies?: boolean
|
sync_proxies?: boolean
|
||||||
|
selected_account_ids?: string[]
|
||||||
}): Promise<{
|
}): Promise<{
|
||||||
created: number
|
created: number
|
||||||
updated: number
|
updated: number
|
||||||
@@ -345,7 +369,19 @@ export async function syncFromCrs(params: {
|
|||||||
error?: string
|
error?: string
|
||||||
}>
|
}>
|
||||||
}> {
|
}> {
|
||||||
const { data } = await apiClient.post('/admin/accounts/sync/crs', params)
|
const { data } = await apiClient.post<{
|
||||||
|
created: number
|
||||||
|
updated: number
|
||||||
|
skipped: number
|
||||||
|
failed: number
|
||||||
|
items: Array<{
|
||||||
|
crs_account_id: string
|
||||||
|
kind: string
|
||||||
|
name: string
|
||||||
|
action: string
|
||||||
|
error?: string
|
||||||
|
}>
|
||||||
|
}>('/admin/accounts/sync/crs', params)
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -442,6 +478,7 @@ export const accountsAPI = {
|
|||||||
batchCreate,
|
batchCreate,
|
||||||
batchUpdateCredentials,
|
batchUpdateCredentials,
|
||||||
bulkUpdate,
|
bulkUpdate,
|
||||||
|
previewFromCrs,
|
||||||
syncFromCrs,
|
syncFromCrs,
|
||||||
exportData,
|
exportData,
|
||||||
importData,
|
importData,
|
||||||
|
|||||||
@@ -53,4 +53,18 @@ export async function exchangeCode(
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
export default { generateAuthUrl, exchangeCode }
|
export async function refreshAntigravityToken(
|
||||||
|
refreshToken: string,
|
||||||
|
proxyId?: number | null
|
||||||
|
): Promise<AntigravityTokenInfo> {
|
||||||
|
const payload: Record<string, any> = { refresh_token: refreshToken }
|
||||||
|
if (proxyId) payload.proxy_id = proxyId
|
||||||
|
|
||||||
|
const { data } = await apiClient.post<AntigravityTokenInfo>(
|
||||||
|
'/admin/antigravity/oauth/refresh-token',
|
||||||
|
payload
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
export default { generateAuthUrl, exchangeCode, refreshAntigravityToken }
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ export interface ErrorPassthroughRule {
|
|||||||
response_code: number | null
|
response_code: number | null
|
||||||
passthrough_body: boolean
|
passthrough_body: boolean
|
||||||
custom_message: string | null
|
custom_message: string | null
|
||||||
|
skip_monitoring: boolean
|
||||||
description: string | null
|
description: string | null
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
@@ -41,6 +42,7 @@ export interface CreateRuleRequest {
|
|||||||
response_code?: number | null
|
response_code?: number | null
|
||||||
passthrough_body?: boolean
|
passthrough_body?: boolean
|
||||||
custom_message?: string | null
|
custom_message?: string | null
|
||||||
|
skip_monitoring?: boolean
|
||||||
description?: string | null
|
description?: string | null
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,6 +61,7 @@ export interface UpdateRuleRequest {
|
|||||||
response_code?: number | null
|
response_code?: number | null
|
||||||
passthrough_body?: boolean
|
passthrough_body?: boolean
|
||||||
custom_message?: string | null
|
custom_message?: string | null
|
||||||
|
skip_monitoring?: boolean
|
||||||
description?: string | null
|
description?: string | null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@
|
|||||||
>
|
>
|
||||||
<div class="mb-2 flex items-center justify-between">
|
<div class="mb-2 flex items-center justify-between">
|
||||||
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
<span class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||||
{{ t('admin.accounts.allGroups', { count: groups.length }) }}
|
{{ t('admin.accounts.groupCountTotal', { count: groups.length }) }}
|
||||||
</span>
|
</span>
|
||||||
<button
|
<button
|
||||||
@click="showPopover = false"
|
@click="showPopover = false"
|
||||||
|
|||||||
@@ -665,8 +665,8 @@
|
|||||||
<Icon name="cloud" size="sm" />
|
<Icon name="cloud" size="sm" />
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{ t('admin.accounts.types.upstream') }}</span>
|
<span class="block text-sm font-medium text-gray-900 dark:text-white">API Key</span>
|
||||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.upstreamDesc') }}</span>
|
<span class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.accounts.types.antigravityApikey') }}</span>
|
||||||
</div>
|
</div>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
@@ -681,7 +681,7 @@
|
|||||||
type="text"
|
type="text"
|
||||||
required
|
required
|
||||||
class="input"
|
class="input"
|
||||||
placeholder="https://s.konstants.xyz"
|
placeholder="https://cloudcode-pa.googleapis.com"
|
||||||
/>
|
/>
|
||||||
<p class="input-hint">{{ t('admin.accounts.upstream.baseUrlHint') }}</p>
|
<p class="input-hint">{{ t('admin.accounts.upstream.baseUrlHint') }}</p>
|
||||||
</div>
|
</div>
|
||||||
@@ -816,8 +816,8 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- API Key input (only for apikey type) -->
|
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
|
||||||
<div v-if="form.type === 'apikey'" class="space-y-4">
|
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
|
||||||
<div>
|
<div>
|
||||||
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
|
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
|
||||||
<input
|
<input
|
||||||
@@ -862,7 +862,7 @@
|
|||||||
<p class="input-hint">{{ t('admin.accounts.gemini.tier.aiStudioHint') }}</p>
|
<p class="input-hint">{{ t('admin.accounts.gemini.tier.aiStudioHint') }}</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Model Restriction Section (不适用于 Gemini) -->
|
<!-- Model Restriction Section (不适用于 Gemini,Antigravity 已在上层条件排除) -->
|
||||||
<div v-if="form.platform !== 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
<div v-if="form.platform !== 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||||
|
|
||||||
@@ -1647,12 +1647,12 @@
|
|||||||
:show-proxy-warning="form.platform !== 'openai' && !!form.proxy_id"
|
:show-proxy-warning="form.platform !== 'openai' && !!form.proxy_id"
|
||||||
:allow-multiple="form.platform === 'anthropic'"
|
:allow-multiple="form.platform === 'anthropic'"
|
||||||
:show-cookie-option="form.platform === 'anthropic'"
|
:show-cookie-option="form.platform === 'anthropic'"
|
||||||
:show-refresh-token-option="form.platform === 'openai'"
|
:show-refresh-token-option="form.platform === 'openai' || form.platform === 'antigravity'"
|
||||||
:platform="form.platform"
|
:platform="form.platform"
|
||||||
:show-project-id="geminiOAuthType === 'code_assist'"
|
:show-project-id="geminiOAuthType === 'code_assist'"
|
||||||
@generate-url="handleGenerateUrl"
|
@generate-url="handleGenerateUrl"
|
||||||
@cookie-auth="handleCookieAuth"
|
@cookie-auth="handleCookieAuth"
|
||||||
@validate-refresh-token="handleOpenAIValidateRT"
|
@validate-refresh-token="handleValidateRefreshToken"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
@@ -2802,6 +2802,14 @@ const handleGenerateUrl = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleValidateRefreshToken = (rt: string) => {
|
||||||
|
if (form.platform === 'openai') {
|
||||||
|
handleOpenAIValidateRT(rt)
|
||||||
|
} else if (form.platform === 'antigravity') {
|
||||||
|
handleAntigravityValidateRT(rt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const formatDateTimeLocal = formatDateTimeLocalInput
|
const formatDateTimeLocal = formatDateTimeLocalInput
|
||||||
const parseDateTimeLocal = parseDateTimeLocalInput
|
const parseDateTimeLocal = parseDateTimeLocalInput
|
||||||
|
|
||||||
@@ -2950,6 +2958,95 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Antigravity 手动 RT 批量验证和创建
|
||||||
|
const handleAntigravityValidateRT = async (refreshTokenInput: string) => {
|
||||||
|
if (!refreshTokenInput.trim()) return
|
||||||
|
|
||||||
|
// Parse multiple refresh tokens (one per line)
|
||||||
|
const refreshTokens = refreshTokenInput
|
||||||
|
.split('\n')
|
||||||
|
.map((rt) => rt.trim())
|
||||||
|
.filter((rt) => rt)
|
||||||
|
|
||||||
|
if (refreshTokens.length === 0) {
|
||||||
|
antigravityOAuth.error.value = t('admin.accounts.oauth.antigravity.pleaseEnterRefreshToken')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
antigravityOAuth.loading.value = true
|
||||||
|
antigravityOAuth.error.value = ''
|
||||||
|
|
||||||
|
let successCount = 0
|
||||||
|
let failedCount = 0
|
||||||
|
const errors: string[] = []
|
||||||
|
|
||||||
|
try {
|
||||||
|
for (let i = 0; i < refreshTokens.length; i++) {
|
||||||
|
try {
|
||||||
|
const tokenInfo = await antigravityOAuth.validateRefreshToken(
|
||||||
|
refreshTokens[i],
|
||||||
|
form.proxy_id
|
||||||
|
)
|
||||||
|
if (!tokenInfo) {
|
||||||
|
failedCount++
|
||||||
|
errors.push(`#${i + 1}: ${antigravityOAuth.error.value || 'Validation failed'}`)
|
||||||
|
antigravityOAuth.error.value = ''
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
const credentials = antigravityOAuth.buildCredentials(tokenInfo)
|
||||||
|
|
||||||
|
// Generate account name with index for batch
|
||||||
|
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
|
||||||
|
|
||||||
|
// Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials
|
||||||
|
await adminAPI.accounts.create({
|
||||||
|
name: accountName,
|
||||||
|
notes: form.notes,
|
||||||
|
platform: 'antigravity',
|
||||||
|
type: 'oauth',
|
||||||
|
credentials,
|
||||||
|
extra: {},
|
||||||
|
proxy_id: form.proxy_id,
|
||||||
|
concurrency: form.concurrency,
|
||||||
|
priority: form.priority,
|
||||||
|
rate_multiplier: form.rate_multiplier,
|
||||||
|
group_ids: form.group_ids,
|
||||||
|
expires_at: form.expires_at,
|
||||||
|
auto_pause_on_expired: autoPauseOnExpired.value
|
||||||
|
})
|
||||||
|
successCount++
|
||||||
|
} catch (error: any) {
|
||||||
|
failedCount++
|
||||||
|
const errMsg = error.response?.data?.detail || error.message || 'Unknown error'
|
||||||
|
errors.push(`#${i + 1}: ${errMsg}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show results
|
||||||
|
if (successCount > 0 && failedCount === 0) {
|
||||||
|
appStore.showSuccess(
|
||||||
|
refreshTokens.length > 1
|
||||||
|
? t('admin.accounts.oauth.batchSuccess', { count: successCount })
|
||||||
|
: t('admin.accounts.accountCreated')
|
||||||
|
)
|
||||||
|
emit('created')
|
||||||
|
handleClose()
|
||||||
|
} else if (successCount > 0 && failedCount > 0) {
|
||||||
|
appStore.showWarning(
|
||||||
|
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
|
||||||
|
)
|
||||||
|
antigravityOAuth.error.value = errors.join('\n')
|
||||||
|
emit('created')
|
||||||
|
} else {
|
||||||
|
antigravityOAuth.error.value = errors.join('\n')
|
||||||
|
appStore.showError(t('admin.accounts.oauth.batchFailed'))
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
antigravityOAuth.loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Gemini OAuth 授权码兑换
|
// Gemini OAuth 授权码兑换
|
||||||
const handleGeminiExchange = async (authCode: string) => {
|
const handleGeminiExchange = async (authCode: string) => {
|
||||||
if (!authCode.trim() || !geminiOAuth.sessionId.value) return
|
if (!authCode.trim() || !geminiOAuth.sessionId.value) return
|
||||||
|
|||||||
@@ -39,7 +39,9 @@
|
|||||||
? 'https://api.openai.com'
|
? 'https://api.openai.com'
|
||||||
: account.platform === 'gemini'
|
: account.platform === 'gemini'
|
||||||
? 'https://generativelanguage.googleapis.com'
|
? 'https://generativelanguage.googleapis.com'
|
||||||
: 'https://api.anthropic.com'
|
: account.platform === 'antigravity'
|
||||||
|
? 'https://cloudcode-pa.googleapis.com'
|
||||||
|
: 'https://api.anthropic.com'
|
||||||
"
|
"
|
||||||
/>
|
/>
|
||||||
<p class="input-hint">{{ baseUrlHint }}</p>
|
<p class="input-hint">{{ baseUrlHint }}</p>
|
||||||
@@ -55,14 +57,16 @@
|
|||||||
? 'sk-proj-...'
|
? 'sk-proj-...'
|
||||||
: account.platform === 'gemini'
|
: account.platform === 'gemini'
|
||||||
? 'AIza...'
|
? 'AIza...'
|
||||||
: 'sk-ant-...'
|
: account.platform === 'antigravity'
|
||||||
|
? 'sk-...'
|
||||||
|
: 'sk-ant-...'
|
||||||
"
|
"
|
||||||
/>
|
/>
|
||||||
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
|
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Model Restriction Section (不适用于 Gemini) -->
|
<!-- Model Restriction Section (不适用于 Gemini 和 Antigravity) -->
|
||||||
<div v-if="account.platform !== 'gemini'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
<div v-if="account.platform !== 'gemini' && account.platform !== 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||||
|
|
||||||
<!-- Mode Toggle -->
|
<!-- Mode Toggle -->
|
||||||
@@ -372,7 +376,7 @@
|
|||||||
v-model="editBaseUrl"
|
v-model="editBaseUrl"
|
||||||
type="text"
|
type="text"
|
||||||
class="input"
|
class="input"
|
||||||
placeholder="https://s.konstants.xyz"
|
placeholder="https://cloudcode-pa.googleapis.com"
|
||||||
/>
|
/>
|
||||||
<p class="input-hint">{{ t('admin.accounts.upstream.baseUrlHint') }}</p>
|
<p class="input-hint">{{ t('admin.accounts.upstream.baseUrlHint') }}</p>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -45,19 +45,19 @@
|
|||||||
class="text-blue-600 focus:ring-blue-500"
|
class="text-blue-600 focus:ring-blue-500"
|
||||||
/>
|
/>
|
||||||
<span class="text-sm text-blue-900 dark:text-blue-200">{{
|
<span class="text-sm text-blue-900 dark:text-blue-200">{{
|
||||||
t('admin.accounts.oauth.openai.refreshTokenAuth')
|
t(getOAuthKey('refreshTokenAuth'))
|
||||||
}}</span>
|
}}</span>
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Refresh Token Input (OpenAI only) -->
|
<!-- Refresh Token Input (OpenAI / Antigravity) -->
|
||||||
<div v-if="inputMethod === 'refresh_token'" class="space-y-4">
|
<div v-if="inputMethod === 'refresh_token'" class="space-y-4">
|
||||||
<div
|
<div
|
||||||
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
|
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
|
||||||
>
|
>
|
||||||
<p class="mb-3 text-sm text-blue-700 dark:text-blue-300">
|
<p class="mb-3 text-sm text-blue-700 dark:text-blue-300">
|
||||||
{{ t('admin.accounts.oauth.openai.refreshTokenDesc') }}
|
{{ t(getOAuthKey('refreshTokenDesc')) }}
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<!-- Refresh Token Input -->
|
<!-- Refresh Token Input -->
|
||||||
@@ -78,7 +78,7 @@
|
|||||||
v-model="refreshTokenInput"
|
v-model="refreshTokenInput"
|
||||||
rows="3"
|
rows="3"
|
||||||
class="input w-full resize-y font-mono text-sm"
|
class="input w-full resize-y font-mono text-sm"
|
||||||
:placeholder="t('admin.accounts.oauth.openai.refreshTokenPlaceholder')"
|
:placeholder="t(getOAuthKey('refreshTokenPlaceholder'))"
|
||||||
></textarea>
|
></textarea>
|
||||||
<p
|
<p
|
||||||
v-if="parsedRefreshTokenCount > 1"
|
v-if="parsedRefreshTokenCount > 1"
|
||||||
@@ -128,8 +128,8 @@
|
|||||||
<Icon v-else name="sparkles" size="sm" class="mr-2" />
|
<Icon v-else name="sparkles" size="sm" class="mr-2" />
|
||||||
{{
|
{{
|
||||||
loading
|
loading
|
||||||
? t('admin.accounts.oauth.openai.validating')
|
? t(getOAuthKey('validating'))
|
||||||
: t('admin.accounts.oauth.openai.validateAndCreate')
|
: t(getOAuthKey('validateAndCreate'))
|
||||||
}}
|
}}
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -6,15 +6,20 @@
|
|||||||
close-on-click-outside
|
close-on-click-outside
|
||||||
@close="handleClose"
|
@close="handleClose"
|
||||||
>
|
>
|
||||||
<form id="sync-from-crs-form" class="space-y-4" @submit.prevent="handleSync">
|
<!-- Step 1: Input credentials -->
|
||||||
|
<form
|
||||||
|
v-if="currentStep === 'input'"
|
||||||
|
id="sync-from-crs-form"
|
||||||
|
class="space-y-4"
|
||||||
|
@submit.prevent="handlePreview"
|
||||||
|
>
|
||||||
<div class="text-sm text-gray-600 dark:text-dark-300">
|
<div class="text-sm text-gray-600 dark:text-dark-300">
|
||||||
{{ t('admin.accounts.syncFromCrsDesc') }}
|
{{ t('admin.accounts.syncFromCrsDesc') }}
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
class="rounded-lg bg-gray-50 p-3 text-xs text-gray-500 dark:bg-dark-700/60 dark:text-dark-400"
|
class="rounded-lg bg-gray-50 p-3 text-xs text-gray-500 dark:bg-dark-700/60 dark:text-dark-400"
|
||||||
>
|
>
|
||||||
已有账号仅同步 CRS
|
{{ t('admin.accounts.crsUpdateBehaviorNote') }}
|
||||||
返回的字段,缺失字段保持原值;凭据按键合并,不会清空未下发的键;未勾选"同步代理"时保留原有代理。
|
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
class="rounded-lg border border-amber-200 bg-amber-50 p-3 text-xs text-amber-600 dark:border-amber-800 dark:bg-amber-900/20 dark:text-amber-400"
|
class="rounded-lg border border-amber-200 bg-amber-50 p-3 text-xs text-amber-600 dark:border-amber-800 dark:bg-amber-900/20 dark:text-amber-400"
|
||||||
@@ -24,26 +29,30 @@
|
|||||||
|
|
||||||
<div class="grid grid-cols-1 gap-4">
|
<div class="grid grid-cols-1 gap-4">
|
||||||
<div>
|
<div>
|
||||||
<label class="input-label">{{ t('admin.accounts.crsBaseUrl') }}</label>
|
<label for="crs-base-url" class="input-label">{{ t('admin.accounts.crsBaseUrl') }}</label>
|
||||||
<input
|
<input
|
||||||
|
id="crs-base-url"
|
||||||
v-model="form.base_url"
|
v-model="form.base_url"
|
||||||
type="text"
|
type="text"
|
||||||
class="input"
|
class="input"
|
||||||
|
required
|
||||||
:placeholder="t('admin.accounts.crsBaseUrlPlaceholder')"
|
:placeholder="t('admin.accounts.crsBaseUrlPlaceholder')"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
|
<div class="grid grid-cols-1 gap-4 sm:grid-cols-2">
|
||||||
<div>
|
<div>
|
||||||
<label class="input-label">{{ t('admin.accounts.crsUsername') }}</label>
|
<label for="crs-username" class="input-label">{{ t('admin.accounts.crsUsername') }}</label>
|
||||||
<input v-model="form.username" type="text" class="input" autocomplete="username" />
|
<input id="crs-username" v-model="form.username" type="text" class="input" required autocomplete="username" />
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<label class="input-label">{{ t('admin.accounts.crsPassword') }}</label>
|
<label for="crs-password" class="input-label">{{ t('admin.accounts.crsPassword') }}</label>
|
||||||
<input
|
<input
|
||||||
|
id="crs-password"
|
||||||
v-model="form.password"
|
v-model="form.password"
|
||||||
type="password"
|
type="password"
|
||||||
class="input"
|
class="input"
|
||||||
|
required
|
||||||
autocomplete="current-password"
|
autocomplete="current-password"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@@ -58,9 +67,101 @@
|
|||||||
{{ t('admin.accounts.syncProxies') }}
|
{{ t('admin.accounts.syncProxies') }}
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<!-- Step 2: Preview & select -->
|
||||||
|
<div v-else-if="currentStep === 'preview' && previewResult" class="space-y-4">
|
||||||
|
<!-- Existing accounts (read-only info) -->
|
||||||
|
<div
|
||||||
|
v-if="previewResult.existing_accounts.length"
|
||||||
|
class="rounded-lg bg-gray-50 p-3 dark:bg-dark-700/60"
|
||||||
|
>
|
||||||
|
<div class="mb-2 text-sm font-medium text-gray-700 dark:text-dark-300">
|
||||||
|
{{ t('admin.accounts.crsExistingAccounts') }}
|
||||||
|
<span class="ml-1 text-xs text-gray-400">({{ previewResult.existing_accounts.length }})</span>
|
||||||
|
</div>
|
||||||
|
<div class="max-h-32 overflow-auto text-xs text-gray-500 dark:text-dark-400">
|
||||||
|
<div
|
||||||
|
v-for="acc in previewResult.existing_accounts"
|
||||||
|
:key="acc.crs_account_id"
|
||||||
|
class="flex items-center gap-2 py-0.5"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
class="inline-block rounded bg-blue-100 px-1.5 py-0.5 text-[10px] font-medium text-blue-700 dark:bg-blue-900/30 dark:text-blue-400"
|
||||||
|
>{{ acc.platform }} / {{ acc.type }}</span>
|
||||||
|
<span class="truncate">{{ acc.name }}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- New accounts (selectable) -->
|
||||||
|
<div v-if="previewResult.new_accounts.length">
|
||||||
|
<div class="mb-2 flex items-center justify-between">
|
||||||
|
<div class="text-sm font-medium text-gray-900 dark:text-white">
|
||||||
|
{{ t('admin.accounts.crsNewAccounts') }}
|
||||||
|
<span class="ml-1 text-xs text-gray-400">({{ previewResult.new_accounts.length }})</span>
|
||||||
|
</div>
|
||||||
|
<div class="flex gap-2">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="text-xs text-blue-600 hover:text-blue-700 dark:text-blue-400"
|
||||||
|
@click="selectAll"
|
||||||
|
>{{ t('admin.accounts.crsSelectAll') }}</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="text-xs text-gray-500 hover:text-gray-600 dark:text-gray-400"
|
||||||
|
@click="selectNone"
|
||||||
|
>{{ t('admin.accounts.crsSelectNone') }}</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
class="max-h-48 overflow-auto rounded-lg border border-gray-200 p-2 dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<label
|
||||||
|
v-for="acc in previewResult.new_accounts"
|
||||||
|
:key="acc.crs_account_id"
|
||||||
|
class="flex cursor-pointer items-center gap-2 rounded px-2 py-1.5 hover:bg-gray-50 dark:hover:bg-dark-700/40"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
:checked="selectedIds.has(acc.crs_account_id)"
|
||||||
|
class="rounded border-gray-300 dark:border-dark-600"
|
||||||
|
@change="toggleSelect(acc.crs_account_id)"
|
||||||
|
/>
|
||||||
|
<span
|
||||||
|
class="inline-block rounded bg-green-100 px-1.5 py-0.5 text-[10px] font-medium text-green-700 dark:bg-green-900/30 dark:text-green-400"
|
||||||
|
>{{ acc.platform }} / {{ acc.type }}</span>
|
||||||
|
<span class="truncate text-sm text-gray-700 dark:text-dark-300">{{ acc.name }}</span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<div class="mt-1 text-xs text-gray-400">
|
||||||
|
{{ t('admin.accounts.crsSelectedCount', { count: selectedIds.size }) }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Sync options summary -->
|
||||||
|
<div class="flex items-center gap-2 text-xs text-gray-500 dark:text-dark-400">
|
||||||
|
<span>{{ t('admin.accounts.syncProxies') }}:</span>
|
||||||
|
<span :class="form.sync_proxies ? 'text-green-600 dark:text-green-400' : 'text-gray-400 dark:text-dark-500'">
|
||||||
|
{{ form.sync_proxies ? t('common.yes') : t('common.no') }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- No new accounts -->
|
||||||
|
<div
|
||||||
|
v-if="!previewResult.new_accounts.length"
|
||||||
|
class="rounded-lg bg-gray-50 p-4 text-center text-sm text-gray-500 dark:bg-dark-700/60 dark:text-dark-400"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.crsNoNewAccounts') }}
|
||||||
|
<span v-if="previewResult.existing_accounts.length">
|
||||||
|
{{ t('admin.accounts.crsWillUpdate', { count: previewResult.existing_accounts.length }) }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Step 3: Result -->
|
||||||
|
<div v-else-if="currentStep === 'result' && result" class="space-y-4">
|
||||||
<div
|
<div
|
||||||
v-if="result"
|
|
||||||
class="space-y-2 rounded-xl border border-gray-200 p-4 dark:border-dark-700"
|
class="space-y-2 rounded-xl border border-gray-200 p-4 dark:border-dark-700"
|
||||||
>
|
>
|
||||||
<div class="text-sm font-medium text-gray-900 dark:text-white">
|
<div class="text-sm font-medium text-gray-900 dark:text-white">
|
||||||
@@ -84,21 +185,56 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</div>
|
||||||
|
|
||||||
<template #footer>
|
<template #footer>
|
||||||
<div class="flex justify-end gap-3">
|
<div class="flex justify-end gap-3">
|
||||||
<button class="btn btn-secondary" type="button" :disabled="syncing" @click="handleClose">
|
<!-- Step 1: Input -->
|
||||||
{{ t('common.cancel') }}
|
<template v-if="currentStep === 'input'">
|
||||||
</button>
|
<button
|
||||||
<button
|
class="btn btn-secondary"
|
||||||
class="btn btn-primary"
|
type="button"
|
||||||
type="submit"
|
:disabled="previewing"
|
||||||
form="sync-from-crs-form"
|
@click="handleClose"
|
||||||
:disabled="syncing"
|
>
|
||||||
>
|
{{ t('common.cancel') }}
|
||||||
{{ syncing ? t('admin.accounts.syncing') : t('admin.accounts.syncNow') }}
|
</button>
|
||||||
</button>
|
<button
|
||||||
|
class="btn btn-primary"
|
||||||
|
type="submit"
|
||||||
|
form="sync-from-crs-form"
|
||||||
|
:disabled="previewing"
|
||||||
|
>
|
||||||
|
{{ previewing ? t('admin.accounts.crsPreviewing') : t('admin.accounts.crsPreview') }}
|
||||||
|
</button>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<!-- Step 2: Preview -->
|
||||||
|
<template v-else-if="currentStep === 'preview'">
|
||||||
|
<button
|
||||||
|
class="btn btn-secondary"
|
||||||
|
type="button"
|
||||||
|
:disabled="syncing"
|
||||||
|
@click="handleBack"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.crsBack') }}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
class="btn btn-primary"
|
||||||
|
type="button"
|
||||||
|
:disabled="syncing || hasNewButNoneSelected"
|
||||||
|
@click="handleSync"
|
||||||
|
>
|
||||||
|
{{ syncing ? t('admin.accounts.syncing') : t('admin.accounts.syncNow') }}
|
||||||
|
</button>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<!-- Step 3: Result -->
|
||||||
|
<template v-else-if="currentStep === 'result'">
|
||||||
|
<button class="btn btn-secondary" type="button" @click="handleClose">
|
||||||
|
{{ t('common.close') }}
|
||||||
|
</button>
|
||||||
|
</template>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
</BaseDialog>
|
</BaseDialog>
|
||||||
@@ -110,6 +246,7 @@ import { useI18n } from 'vue-i18n'
|
|||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
|
import type { PreviewFromCRSResult } from '@/api/admin/accounts'
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
show: boolean
|
show: boolean
|
||||||
@@ -126,7 +263,12 @@ const emit = defineEmits<Emits>()
|
|||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
|
|
||||||
|
type Step = 'input' | 'preview' | 'result'
|
||||||
|
const currentStep = ref<Step>('input')
|
||||||
|
const previewing = ref(false)
|
||||||
const syncing = ref(false)
|
const syncing = ref(false)
|
||||||
|
const previewResult = ref<PreviewFromCRSResult | null>(null)
|
||||||
|
const selectedIds = ref(new Set<string>())
|
||||||
const result = ref<Awaited<ReturnType<typeof adminAPI.accounts.syncFromCrs>> | null>(null)
|
const result = ref<Awaited<ReturnType<typeof adminAPI.accounts.syncFromCrs>> | null>(null)
|
||||||
|
|
||||||
const form = reactive({
|
const form = reactive({
|
||||||
@@ -136,28 +278,90 @@ const form = reactive({
|
|||||||
sync_proxies: true
|
sync_proxies: true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const hasNewButNoneSelected = computed(() => {
|
||||||
|
if (!previewResult.value) return false
|
||||||
|
return previewResult.value.new_accounts.length > 0 && selectedIds.value.size === 0
|
||||||
|
})
|
||||||
|
|
||||||
const errorItems = computed(() => {
|
const errorItems = computed(() => {
|
||||||
if (!result.value?.items) return []
|
if (!result.value?.items) return []
|
||||||
return result.value.items.filter((i) => i.action === 'failed' || i.action === 'skipped')
|
return result.value.items.filter(
|
||||||
|
(i) => i.action === 'failed' || (i.action === 'skipped' && i.error !== 'not selected')
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
watch(
|
watch(
|
||||||
() => props.show,
|
() => props.show,
|
||||||
(open) => {
|
(open) => {
|
||||||
if (open) {
|
if (open) {
|
||||||
|
currentStep.value = 'input'
|
||||||
|
previewResult.value = null
|
||||||
|
selectedIds.value = new Set()
|
||||||
result.value = null
|
result.value = null
|
||||||
|
form.base_url = ''
|
||||||
|
form.username = ''
|
||||||
|
form.password = ''
|
||||||
|
form.sync_proxies = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
const handleClose = () => {
|
const handleClose = () => {
|
||||||
// 防止在同步进行中关闭对话框
|
if (syncing.value || previewing.value) {
|
||||||
if (syncing.value) {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
emit('close')
|
emit('close')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleBack = () => {
|
||||||
|
currentStep.value = 'input'
|
||||||
|
previewResult.value = null
|
||||||
|
selectedIds.value = new Set()
|
||||||
|
}
|
||||||
|
|
||||||
|
const selectAll = () => {
|
||||||
|
if (!previewResult.value) return
|
||||||
|
selectedIds.value = new Set(previewResult.value.new_accounts.map((a) => a.crs_account_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
const selectNone = () => {
|
||||||
|
selectedIds.value = new Set()
|
||||||
|
}
|
||||||
|
|
||||||
|
const toggleSelect = (id: string) => {
|
||||||
|
const s = new Set(selectedIds.value)
|
||||||
|
if (s.has(id)) {
|
||||||
|
s.delete(id)
|
||||||
|
} else {
|
||||||
|
s.add(id)
|
||||||
|
}
|
||||||
|
selectedIds.value = s
|
||||||
|
}
|
||||||
|
|
||||||
|
const handlePreview = async () => {
|
||||||
|
if (!form.base_url.trim() || !form.username.trim() || !form.password.trim()) {
|
||||||
|
appStore.showError(t('admin.accounts.syncMissingFields'))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
previewing.value = true
|
||||||
|
try {
|
||||||
|
const res = await adminAPI.accounts.previewFromCrs({
|
||||||
|
base_url: form.base_url.trim(),
|
||||||
|
username: form.username.trim(),
|
||||||
|
password: form.password
|
||||||
|
})
|
||||||
|
previewResult.value = res
|
||||||
|
// Auto-select all new accounts
|
||||||
|
selectedIds.value = new Set(res.new_accounts.map((a) => a.crs_account_id))
|
||||||
|
currentStep.value = 'preview'
|
||||||
|
} catch (error: any) {
|
||||||
|
appStore.showError(error?.message || t('admin.accounts.crsPreviewFailed'))
|
||||||
|
} finally {
|
||||||
|
previewing.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const handleSync = async () => {
|
const handleSync = async () => {
|
||||||
if (!form.base_url.trim() || !form.username.trim() || !form.password.trim()) {
|
if (!form.base_url.trim() || !form.username.trim() || !form.password.trim()) {
|
||||||
appStore.showError(t('admin.accounts.syncMissingFields'))
|
appStore.showError(t('admin.accounts.syncMissingFields'))
|
||||||
@@ -170,16 +374,18 @@ const handleSync = async () => {
|
|||||||
base_url: form.base_url.trim(),
|
base_url: form.base_url.trim(),
|
||||||
username: form.username.trim(),
|
username: form.username.trim(),
|
||||||
password: form.password,
|
password: form.password,
|
||||||
sync_proxies: form.sync_proxies
|
sync_proxies: form.sync_proxies,
|
||||||
|
selected_account_ids: [...selectedIds.value]
|
||||||
})
|
})
|
||||||
result.value = res
|
result.value = res
|
||||||
|
currentStep.value = 'result'
|
||||||
|
|
||||||
if (res.failed > 0) {
|
if (res.failed > 0) {
|
||||||
appStore.showError(t('admin.accounts.syncCompletedWithErrors', res))
|
appStore.showError(t('admin.accounts.syncCompletedWithErrors', res))
|
||||||
} else {
|
} else {
|
||||||
appStore.showSuccess(t('admin.accounts.syncCompleted', res))
|
appStore.showSuccess(t('admin.accounts.syncCompleted', res))
|
||||||
emit('synced')
|
|
||||||
}
|
}
|
||||||
|
emit('synced')
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
appStore.showError(error?.message || t('admin.accounts.syncFailed'))
|
appStore.showError(error?.message || t('admin.accounts.syncFailed'))
|
||||||
} finally {
|
} finally {
|
||||||
|
|||||||
@@ -148,6 +148,16 @@
|
|||||||
{{ rule.passthrough_body ? t('admin.errorPassthrough.passthrough') : t('admin.errorPassthrough.custom') }}
|
{{ rule.passthrough_body ? t('admin.errorPassthrough.passthrough') : t('admin.errorPassthrough.custom') }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
<div v-if="rule.skip_monitoring" class="flex items-center gap-1">
|
||||||
|
<Icon
|
||||||
|
name="checkCircle"
|
||||||
|
size="xs"
|
||||||
|
class="text-yellow-500"
|
||||||
|
/>
|
||||||
|
<span class="text-gray-600 dark:text-gray-400">
|
||||||
|
{{ t('admin.errorPassthrough.skipMonitoring') }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</td>
|
</td>
|
||||||
<td class="px-3 py-2">
|
<td class="px-3 py-2">
|
||||||
@@ -366,6 +376,19 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Skip Monitoring -->
|
||||||
|
<div class="flex items-center gap-1.5">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
v-model="form.skip_monitoring"
|
||||||
|
class="h-3.5 w-3.5 rounded border-gray-300 text-yellow-600 focus:ring-yellow-500"
|
||||||
|
/>
|
||||||
|
<span class="text-xs font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.errorPassthrough.form.skipMonitoring') }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<p class="input-hint text-xs -mt-3">{{ t('admin.errorPassthrough.form.skipMonitoringHint') }}</p>
|
||||||
|
|
||||||
<!-- Enabled -->
|
<!-- Enabled -->
|
||||||
<div class="flex items-center gap-1.5">
|
<div class="flex items-center gap-1.5">
|
||||||
<input
|
<input
|
||||||
@@ -453,6 +476,7 @@ const form = reactive({
|
|||||||
response_code: null as number | null,
|
response_code: null as number | null,
|
||||||
passthrough_body: true,
|
passthrough_body: true,
|
||||||
custom_message: null as string | null,
|
custom_message: null as string | null,
|
||||||
|
skip_monitoring: false,
|
||||||
description: null as string | null
|
description: null as string | null
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -497,6 +521,7 @@ const resetForm = () => {
|
|||||||
form.response_code = null
|
form.response_code = null
|
||||||
form.passthrough_body = true
|
form.passthrough_body = true
|
||||||
form.custom_message = null
|
form.custom_message = null
|
||||||
|
form.skip_monitoring = false
|
||||||
form.description = null
|
form.description = null
|
||||||
errorCodesInput.value = ''
|
errorCodesInput.value = ''
|
||||||
keywordsInput.value = ''
|
keywordsInput.value = ''
|
||||||
@@ -520,6 +545,7 @@ const handleEdit = (rule: ErrorPassthroughRule) => {
|
|||||||
form.response_code = rule.response_code
|
form.response_code = rule.response_code
|
||||||
form.passthrough_body = rule.passthrough_body
|
form.passthrough_body = rule.passthrough_body
|
||||||
form.custom_message = rule.custom_message
|
form.custom_message = rule.custom_message
|
||||||
|
form.skip_monitoring = rule.skip_monitoring
|
||||||
form.description = rule.description
|
form.description = rule.description
|
||||||
errorCodesInput.value = rule.error_codes.join(', ')
|
errorCodesInput.value = rule.error_codes.join(', ')
|
||||||
keywordsInput.value = rule.keywords.join('\n')
|
keywordsInput.value = rule.keywords.join('\n')
|
||||||
@@ -575,6 +601,7 @@ const handleSubmit = async () => {
|
|||||||
response_code: form.passthrough_code ? null : form.response_code,
|
response_code: form.passthrough_code ? null : form.response_code,
|
||||||
passthrough_body: form.passthrough_body,
|
passthrough_body: form.passthrough_body,
|
||||||
custom_message: form.passthrough_body ? null : form.custom_message,
|
custom_message: form.passthrough_body ? null : form.custom_message,
|
||||||
|
skip_monitoring: form.skip_monitoring,
|
||||||
description: form.description?.trim() || null
|
description: form.description?.trim() || null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,19 @@ import type { Account } from '@/types'
|
|||||||
const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>()
|
const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>()
|
||||||
const emit = defineEmits(['close', 'test', 'stats', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit'])
|
const emit = defineEmits(['close', 'test', 'stats', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit'])
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const isRateLimited = computed(() => props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date())
|
const isRateLimited = computed(() => {
|
||||||
|
if (props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
const modelLimits = (props.account?.extra as Record<string, unknown> | undefined)?.model_rate_limits as
|
||||||
|
| Record<string, { rate_limit_reset_at: string }>
|
||||||
|
| undefined
|
||||||
|
if (modelLimits) {
|
||||||
|
const now = new Date()
|
||||||
|
return Object.values(modelLimits).some(info => new Date(info.rate_limit_reset_at) > now)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date())
|
const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date())
|
||||||
|
|
||||||
const handleKeydown = (event: KeyboardEvent) => {
|
const handleKeydown = (event: KeyboardEvent) => {
|
||||||
|
|||||||
@@ -10,16 +10,21 @@
|
|||||||
<Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" />
|
<Select :model-value="filters.platform" class="w-40" :options="pOpts" @update:model-value="updatePlatform" @change="$emit('change')" />
|
||||||
<Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" />
|
<Select :model-value="filters.type" class="w-40" :options="tOpts" @update:model-value="updateType" @change="$emit('change')" />
|
||||||
<Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" />
|
<Select :model-value="filters.status" class="w-40" :options="sOpts" @update:model-value="updateStatus" @change="$emit('change')" />
|
||||||
|
<Select :model-value="filters.group" class="w-40" :options="gOpts" @update:model-value="updateGroup" @change="$emit('change')" />
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed } from 'vue'; import { useI18n } from 'vue-i18n'; import Select from '@/components/common/Select.vue'; import SearchInput from '@/components/common/SearchInput.vue'
|
import { computed } from 'vue'; import { useI18n } from 'vue-i18n'; import Select from '@/components/common/Select.vue'; import SearchInput from '@/components/common/SearchInput.vue'
|
||||||
const props = defineProps(['searchQuery', 'filters']); const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); const { t } = useI18n()
|
import type { AdminGroup } from '@/types'
|
||||||
|
const props = defineProps<{ searchQuery: string; filters: Record<string, any>; groups?: AdminGroup[] }>()
|
||||||
|
const emit = defineEmits(['update:searchQuery', 'update:filters', 'change']); const { t } = useI18n()
|
||||||
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
|
const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) }
|
||||||
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
|
const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) }
|
||||||
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
|
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
|
||||||
|
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
|
||||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }])
|
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }])
|
||||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
|
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
|
||||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }])
|
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }])
|
||||||
|
const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))])
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -70,6 +70,7 @@
|
|||||||
<div v-if="row.cache_creation_tokens > 0" class="inline-flex items-center gap-1">
|
<div v-if="row.cache_creation_tokens > 0" class="inline-flex items-center gap-1">
|
||||||
<svg class="h-3.5 w-3.5 text-amber-500" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" /></svg>
|
<svg class="h-3.5 w-3.5 text-amber-500" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" /></svg>
|
||||||
<span class="font-medium text-amber-600 dark:text-amber-400">{{ formatCacheTokens(row.cache_creation_tokens) }}</span>
|
<span class="font-medium text-amber-600 dark:text-amber-400">{{ formatCacheTokens(row.cache_creation_tokens) }}</span>
|
||||||
|
<span v-if="row.cache_creation_1h_tokens > 0" class="inline-flex items-center rounded px-1 py-px text-[10px] font-medium leading-tight bg-orange-100 text-orange-600 ring-1 ring-inset ring-orange-200 dark:bg-orange-500/20 dark:text-orange-400 dark:ring-orange-500/30">1h</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -157,9 +158,29 @@
|
|||||||
<span class="text-gray-400">{{ t('admin.usage.outputTokens') }}</span>
|
<span class="text-gray-400">{{ t('admin.usage.outputTokens') }}</span>
|
||||||
<span class="font-medium text-white">{{ tokenTooltipData.output_tokens.toLocaleString() }}</span>
|
<span class="font-medium text-white">{{ tokenTooltipData.output_tokens.toLocaleString() }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div v-if="tokenTooltipData && tokenTooltipData.cache_creation_tokens > 0" class="flex items-center justify-between gap-4">
|
<div v-if="tokenTooltipData && tokenTooltipData.cache_creation_tokens > 0">
|
||||||
<span class="text-gray-400">{{ t('admin.usage.cacheCreationTokens') }}</span>
|
<!-- 有 5m/1h 明细时,展开显示 -->
|
||||||
<span class="font-medium text-white">{{ tokenTooltipData.cache_creation_tokens.toLocaleString() }}</span>
|
<template v-if="tokenTooltipData.cache_creation_5m_tokens > 0 || tokenTooltipData.cache_creation_1h_tokens > 0">
|
||||||
|
<div v-if="tokenTooltipData.cache_creation_5m_tokens > 0" class="flex items-center justify-between gap-4">
|
||||||
|
<span class="text-gray-400 flex items-center gap-1.5">
|
||||||
|
{{ t('admin.usage.cacheCreation5mTokens') }}
|
||||||
|
<span class="inline-flex items-center rounded px-1 py-px text-[10px] font-medium leading-tight bg-amber-500/20 text-amber-400 ring-1 ring-inset ring-amber-500/30">5m</span>
|
||||||
|
</span>
|
||||||
|
<span class="font-medium text-white">{{ tokenTooltipData.cache_creation_5m_tokens.toLocaleString() }}</span>
|
||||||
|
</div>
|
||||||
|
<div v-if="tokenTooltipData.cache_creation_1h_tokens > 0" class="flex items-center justify-between gap-4">
|
||||||
|
<span class="text-gray-400 flex items-center gap-1.5">
|
||||||
|
{{ t('admin.usage.cacheCreation1hTokens') }}
|
||||||
|
<span class="inline-flex items-center rounded px-1 py-px text-[10px] font-medium leading-tight bg-orange-500/20 text-orange-400 ring-1 ring-inset ring-orange-500/30">1h</span>
|
||||||
|
</span>
|
||||||
|
<span class="font-medium text-white">{{ tokenTooltipData.cache_creation_1h_tokens.toLocaleString() }}</span>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<!-- 无明细时,只显示聚合值 -->
|
||||||
|
<div v-else class="flex items-center justify-between gap-4">
|
||||||
|
<span class="text-gray-400">{{ t('admin.usage.cacheCreationTokens') }}</span>
|
||||||
|
<span class="font-medium text-white">{{ tokenTooltipData.cache_creation_tokens.toLocaleString() }}</span>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div v-if="tokenTooltipData && tokenTooltipData.cache_read_tokens > 0" class="flex items-center justify-between gap-4">
|
<div v-if="tokenTooltipData && tokenTooltipData.cache_read_tokens > 0" class="flex items-center justify-between gap-4">
|
||||||
<span class="text-gray-400">{{ t('admin.usage.cacheReadTokens') }}</span>
|
<span class="text-gray-400">{{ t('admin.usage.cacheReadTokens') }}</span>
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
/>
|
/>
|
||||||
<GroupBadge
|
<GroupBadge
|
||||||
:name="group.name"
|
:name="group.name"
|
||||||
|
:platform="group.platform"
|
||||||
:subscription-type="group.subscription_type"
|
:subscription-type="group.subscription_type"
|
||||||
:rate-multiplier="group.rate_multiplier"
|
:rate-multiplier="group.rate_multiplier"
|
||||||
class="min-w-0 flex-1"
|
class="min-w-0 flex-1"
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
<div class="min-w-0 flex-1">
|
<div class="min-w-0 flex-1">
|
||||||
<p class="stat-label truncate">{{ title }}</p>
|
<p class="stat-label truncate">{{ title }}</p>
|
||||||
<div class="mt-1 flex items-baseline gap-2">
|
<div class="mt-1 flex items-baseline gap-2">
|
||||||
<p class="stat-value">{{ formattedValue }}</p>
|
<p class="stat-value" :title="String(formattedValue)">{{ formattedValue }}</p>
|
||||||
<span v-if="change !== undefined" :class="['stat-trend', trendClass]">
|
<span v-if="change !== undefined" :class="['stat-trend', trendClass]">
|
||||||
<Icon
|
<Icon
|
||||||
v-if="changeType !== 'neutral'"
|
v-if="changeType !== 'neutral'"
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
<div class="sidebar-header">
|
<div class="sidebar-header">
|
||||||
<!-- Custom Logo or Default Logo -->
|
<!-- Custom Logo or Default Logo -->
|
||||||
<div class="flex h-9 w-9 items-center justify-center overflow-hidden rounded-xl shadow-glow">
|
<div class="flex h-9 w-9 items-center justify-center overflow-hidden rounded-xl shadow-glow">
|
||||||
<img :src="siteLogo || '/logo.png'" alt="Logo" class="h-full w-full object-contain" />
|
<img v-if="settingsLoaded" :src="siteLogo || '/logo.png'" alt="Logo" class="h-full w-full object-contain" />
|
||||||
</div>
|
</div>
|
||||||
<transition name="fade">
|
<transition name="fade">
|
||||||
<div v-if="!sidebarCollapsed" class="flex flex-col">
|
<div v-if="!sidebarCollapsed" class="flex flex-col">
|
||||||
@@ -167,6 +167,7 @@ const isDark = ref(document.documentElement.classList.contains('dark'))
|
|||||||
const siteName = computed(() => appStore.siteName)
|
const siteName = computed(() => appStore.siteName)
|
||||||
const siteLogo = computed(() => appStore.siteLogo)
|
const siteLogo = computed(() => appStore.siteLogo)
|
||||||
const siteVersion = computed(() => appStore.siteVersion)
|
const siteVersion = computed(() => appStore.siteVersion)
|
||||||
|
const settingsLoaded = computed(() => appStore.publicSettingsLoaded)
|
||||||
|
|
||||||
// SVG Icon Components
|
// SVG Icon Components
|
||||||
const DashboardIcon = {
|
const DashboardIcon = {
|
||||||
|
|||||||
@@ -29,17 +29,19 @@
|
|||||||
<!-- Logo/Brand -->
|
<!-- Logo/Brand -->
|
||||||
<div class="mb-8 text-center">
|
<div class="mb-8 text-center">
|
||||||
<!-- Custom Logo or Default Logo -->
|
<!-- Custom Logo or Default Logo -->
|
||||||
<div
|
<template v-if="settingsLoaded">
|
||||||
class="mb-4 inline-flex h-16 w-16 items-center justify-center overflow-hidden rounded-2xl shadow-lg shadow-primary-500/30"
|
<div
|
||||||
>
|
class="mb-4 inline-flex h-16 w-16 items-center justify-center overflow-hidden rounded-2xl shadow-lg shadow-primary-500/30"
|
||||||
<img :src="siteLogo || '/logo.png'" alt="Logo" class="h-full w-full object-contain" />
|
>
|
||||||
</div>
|
<img :src="siteLogo || '/logo.png'" alt="Logo" class="h-full w-full object-contain" />
|
||||||
<h1 class="text-gradient mb-2 text-3xl font-bold">
|
</div>
|
||||||
{{ siteName }}
|
<h1 class="text-gradient mb-2 text-3xl font-bold">
|
||||||
</h1>
|
{{ siteName }}
|
||||||
<p class="text-sm text-gray-500 dark:text-dark-400">
|
</h1>
|
||||||
{{ siteSubtitle }}
|
<p class="text-sm text-gray-500 dark:text-dark-400">
|
||||||
</p>
|
{{ siteSubtitle }}
|
||||||
|
</p>
|
||||||
|
</template>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Card Container -->
|
<!-- Card Container -->
|
||||||
@@ -61,25 +63,21 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { computed, onMounted } from 'vue'
|
||||||
import { getPublicSettings } from '@/api/auth'
|
import { useAppStore } from '@/stores'
|
||||||
import { sanitizeUrl } from '@/utils/url'
|
import { sanitizeUrl } from '@/utils/url'
|
||||||
|
|
||||||
const siteName = ref('Sub2API')
|
const appStore = useAppStore()
|
||||||
const siteLogo = ref('')
|
|
||||||
const siteSubtitle = ref('Subscription to API Conversion Platform')
|
const siteName = computed(() => appStore.siteName || 'Sub2API')
|
||||||
|
const siteLogo = computed(() => sanitizeUrl(appStore.siteLogo || '', { allowRelative: true, allowDataUrl: true }))
|
||||||
|
const siteSubtitle = computed(() => appStore.cachedPublicSettings?.site_subtitle || 'Subscription to API Conversion Platform')
|
||||||
|
const settingsLoaded = computed(() => appStore.publicSettingsLoaded)
|
||||||
|
|
||||||
const currentYear = computed(() => new Date().getFullYear())
|
const currentYear = computed(() => new Date().getFullYear())
|
||||||
|
|
||||||
onMounted(async () => {
|
onMounted(() => {
|
||||||
try {
|
appStore.fetchPublicSettings()
|
||||||
const settings = await getPublicSettings()
|
|
||||||
siteName.value = settings.site_name || 'Sub2API'
|
|
||||||
siteLogo.value = sanitizeUrl(settings.site_logo || '', { allowRelative: true })
|
|
||||||
siteSubtitle.value = settings.site_subtitle || 'Subscription to API Conversion Platform'
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Failed to load public settings:', error)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
|||||||
@@ -83,6 +83,35 @@ export function useAntigravityOAuth() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const validateRefreshToken = async (
|
||||||
|
refreshToken: string,
|
||||||
|
proxyId?: number | null
|
||||||
|
): Promise<AntigravityTokenInfo | null> => {
|
||||||
|
if (!refreshToken.trim()) {
|
||||||
|
error.value = t('admin.accounts.oauth.antigravity.pleaseEnterRefreshToken')
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
loading.value = true
|
||||||
|
error.value = ''
|
||||||
|
|
||||||
|
try {
|
||||||
|
const tokenInfo = await adminAPI.antigravity.refreshAntigravityToken(
|
||||||
|
refreshToken.trim(),
|
||||||
|
proxyId
|
||||||
|
)
|
||||||
|
return tokenInfo as AntigravityTokenInfo
|
||||||
|
} catch (err: any) {
|
||||||
|
error.value =
|
||||||
|
err.response?.data?.detail || t('admin.accounts.oauth.antigravity.failedToValidateRT')
|
||||||
|
// Don't show global error toast for batch validation to avoid spamming
|
||||||
|
// appStore.showError(error.value)
|
||||||
|
return null
|
||||||
|
} finally {
|
||||||
|
loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const buildCredentials = (tokenInfo: AntigravityTokenInfo): Record<string, unknown> => {
|
const buildCredentials = (tokenInfo: AntigravityTokenInfo): Record<string, unknown> => {
|
||||||
let expiresAt: string | undefined
|
let expiresAt: string | undefined
|
||||||
if (typeof tokenInfo.expires_at === 'number' && Number.isFinite(tokenInfo.expires_at)) {
|
if (typeof tokenInfo.expires_at === 'number' && Number.isFinite(tokenInfo.expires_at)) {
|
||||||
@@ -110,6 +139,7 @@ export function useAntigravityOAuth() {
|
|||||||
resetState,
|
resetState,
|
||||||
generateAuthUrl,
|
generateAuthUrl,
|
||||||
exchangeAuthCode,
|
exchangeAuthCode,
|
||||||
|
validateRefreshToken,
|
||||||
buildCredentials
|
buildCredentials
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -841,7 +841,7 @@ export default {
|
|||||||
createUser: 'Create User',
|
createUser: 'Create User',
|
||||||
editUser: 'Edit User',
|
editUser: 'Edit User',
|
||||||
deleteUser: 'Delete User',
|
deleteUser: 'Delete User',
|
||||||
searchUsers: 'Search users...',
|
searchUsers: 'Search by email, username, notes, or API key...',
|
||||||
allRoles: 'All Roles',
|
allRoles: 'All Roles',
|
||||||
allStatus: 'All Status',
|
allStatus: 'All Status',
|
||||||
admin: 'Admin',
|
admin: 'Admin',
|
||||||
@@ -1309,10 +1309,23 @@ export default {
|
|||||||
syncResult: 'Sync Result',
|
syncResult: 'Sync Result',
|
||||||
syncResultSummary: 'Created {created}, updated {updated}, skipped {skipped}, failed {failed}',
|
syncResultSummary: 'Created {created}, updated {updated}, skipped {skipped}, failed {failed}',
|
||||||
syncErrors: 'Errors / Skipped Details',
|
syncErrors: 'Errors / Skipped Details',
|
||||||
syncCompleted: 'Sync completed: created {created}, updated {updated}',
|
syncCompleted: 'Sync completed: created {created}, updated {updated}, skipped {skipped}',
|
||||||
syncCompletedWithErrors:
|
syncCompletedWithErrors:
|
||||||
'Sync completed with errors: failed {failed} (created {created}, updated {updated})',
|
'Sync completed with errors: failed {failed} (created {created}, updated {updated}, skipped {skipped})',
|
||||||
syncFailed: 'Sync failed',
|
syncFailed: 'Sync failed',
|
||||||
|
crsPreview: 'Preview',
|
||||||
|
crsPreviewing: 'Previewing...',
|
||||||
|
crsPreviewFailed: 'Preview failed',
|
||||||
|
crsExistingAccounts: 'Existing accounts (will be updated)',
|
||||||
|
crsNewAccounts: 'New accounts (select to sync)',
|
||||||
|
crsSelectAll: 'Select all',
|
||||||
|
crsSelectNone: 'Select none',
|
||||||
|
crsNoNewAccounts: 'All CRS accounts are already synced.',
|
||||||
|
crsWillUpdate: 'Will update {count} existing accounts.',
|
||||||
|
crsSelectedCount: '{count} new accounts selected',
|
||||||
|
crsUpdateBehaviorNote:
|
||||||
|
'Existing accounts only sync fields returned by CRS; missing fields keep their current values. Credentials are merged by key — keys not returned by CRS are preserved. Proxies are kept when "Sync proxies" is unchecked.',
|
||||||
|
crsBack: 'Back',
|
||||||
editAccount: 'Edit Account',
|
editAccount: 'Edit Account',
|
||||||
deleteAccount: 'Delete Account',
|
deleteAccount: 'Delete Account',
|
||||||
searchAccounts: 'Search accounts...',
|
searchAccounts: 'Search accounts...',
|
||||||
@@ -1322,6 +1335,7 @@ export default {
|
|||||||
allPlatforms: 'All Platforms',
|
allPlatforms: 'All Platforms',
|
||||||
allTypes: 'All Types',
|
allTypes: 'All Types',
|
||||||
allStatus: 'All Status',
|
allStatus: 'All Status',
|
||||||
|
allGroups: 'All Groups',
|
||||||
oauthType: 'OAuth',
|
oauthType: 'OAuth',
|
||||||
setupToken: 'Setup Token',
|
setupToken: 'Setup Token',
|
||||||
apiKey: 'API Key',
|
apiKey: 'API Key',
|
||||||
@@ -1331,7 +1345,7 @@ export default {
|
|||||||
schedulableEnabled: 'Scheduling enabled',
|
schedulableEnabled: 'Scheduling enabled',
|
||||||
schedulableDisabled: 'Scheduling disabled',
|
schedulableDisabled: 'Scheduling disabled',
|
||||||
failedToToggleSchedulable: 'Failed to toggle scheduling status',
|
failedToToggleSchedulable: 'Failed to toggle scheduling status',
|
||||||
allGroups: '{count} groups total',
|
groupCountTotal: '{count} groups total',
|
||||||
platforms: {
|
platforms: {
|
||||||
anthropic: 'Anthropic',
|
anthropic: 'Anthropic',
|
||||||
claude: 'Claude',
|
claude: 'Claude',
|
||||||
@@ -1346,6 +1360,7 @@ export default {
|
|||||||
googleOauth: 'Google OAuth',
|
googleOauth: 'Google OAuth',
|
||||||
codeAssist: 'Code Assist',
|
codeAssist: 'Code Assist',
|
||||||
antigravityOauth: 'Antigravity OAuth',
|
antigravityOauth: 'Antigravity OAuth',
|
||||||
|
antigravityApikey: 'Connect via Base URL + API Key',
|
||||||
upstream: 'Upstream',
|
upstream: 'Upstream',
|
||||||
upstreamDesc: 'Connect via Base URL + API Key'
|
upstreamDesc: 'Connect via Base URL + API Key'
|
||||||
},
|
},
|
||||||
@@ -1612,7 +1627,7 @@ export default {
|
|||||||
// Upstream type
|
// Upstream type
|
||||||
upstream: {
|
upstream: {
|
||||||
baseUrl: 'Upstream Base URL',
|
baseUrl: 'Upstream Base URL',
|
||||||
baseUrlHint: 'The address of the upstream Antigravity service, e.g., https://s.konstants.xyz',
|
baseUrlHint: 'The address of the upstream Antigravity service, e.g., https://cloudcode-pa.googleapis.com',
|
||||||
apiKey: 'Upstream API Key',
|
apiKey: 'Upstream API Key',
|
||||||
apiKeyHint: 'API Key for the upstream service',
|
apiKeyHint: 'API Key for the upstream service',
|
||||||
pleaseEnterBaseUrl: 'Please enter upstream Base URL',
|
pleaseEnterBaseUrl: 'Please enter upstream Base URL',
|
||||||
@@ -1760,13 +1775,20 @@ export default {
|
|||||||
authCode: 'Authorization URL or Code',
|
authCode: 'Authorization URL or Code',
|
||||||
authCodePlaceholder:
|
authCodePlaceholder:
|
||||||
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
|
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
|
||||||
authCodeHint: 'You can copy the entire URL or just the code parameter value, the system will auto-detect',
|
authCodeHint: 'You can copy the entire URL or just the code parameter value, the system will auto-detect',
|
||||||
failedToGenerateUrl: 'Failed to generate Antigravity auth URL',
|
failedToGenerateUrl: 'Failed to generate Antigravity auth URL',
|
||||||
missingExchangeParams: 'Missing code, session ID, or state',
|
missingExchangeParams: 'Missing code, session ID, or state',
|
||||||
failedToExchangeCode: 'Failed to exchange Antigravity auth code'
|
failedToExchangeCode: 'Failed to exchange Antigravity auth code',
|
||||||
}
|
// Refresh Token auth
|
||||||
},
|
refreshTokenAuth: 'Manual RT',
|
||||||
// Gemini specific (platform-wide)
|
refreshTokenDesc: 'Enter your existing Antigravity Refresh Token. Supports batch input (one per line). The system will automatically validate and create accounts.',
|
||||||
|
refreshTokenPlaceholder: 'Paste your Antigravity Refresh Token...\nSupports multiple tokens, one per line',
|
||||||
|
validating: 'Validating...',
|
||||||
|
validateAndCreate: 'Validate & Create',
|
||||||
|
pleaseEnterRefreshToken: 'Please enter Refresh Token',
|
||||||
|
failedToValidateRT: 'Failed to validate Refresh Token'
|
||||||
|
}
|
||||||
|
}, // Gemini specific (platform-wide)
|
||||||
gemini: {
|
gemini: {
|
||||||
helpButton: 'Help',
|
helpButton: 'Help',
|
||||||
helpDialog: {
|
helpDialog: {
|
||||||
@@ -2115,7 +2137,7 @@ export default {
|
|||||||
title: 'Redeem Code Management',
|
title: 'Redeem Code Management',
|
||||||
description: 'Generate and manage redeem codes',
|
description: 'Generate and manage redeem codes',
|
||||||
generateCodes: 'Generate Codes',
|
generateCodes: 'Generate Codes',
|
||||||
searchCodes: 'Search codes...',
|
searchCodes: 'Search codes or email...',
|
||||||
allTypes: 'All Types',
|
allTypes: 'All Types',
|
||||||
allStatus: 'All Status',
|
allStatus: 'All Status',
|
||||||
balance: 'Balance',
|
balance: 'Balance',
|
||||||
@@ -2338,6 +2360,8 @@ export default {
|
|||||||
inputTokens: 'Input Tokens',
|
inputTokens: 'Input Tokens',
|
||||||
outputTokens: 'Output Tokens',
|
outputTokens: 'Output Tokens',
|
||||||
cacheCreationTokens: 'Cache Creation Tokens',
|
cacheCreationTokens: 'Cache Creation Tokens',
|
||||||
|
cacheCreation5mTokens: 'Cache Write',
|
||||||
|
cacheCreation1hTokens: 'Cache Write',
|
||||||
cacheReadTokens: 'Cache Read Tokens',
|
cacheReadTokens: 'Cache Read Tokens',
|
||||||
failedToLoad: 'Failed to load usage records',
|
failedToLoad: 'Failed to load usage records',
|
||||||
billingType: 'Billing Type',
|
billingType: 'Billing Type',
|
||||||
@@ -3339,6 +3363,7 @@ export default {
|
|||||||
custom: 'Custom',
|
custom: 'Custom',
|
||||||
code: 'Code',
|
code: 'Code',
|
||||||
body: 'Body',
|
body: 'Body',
|
||||||
|
skipMonitoring: 'Skip Monitoring',
|
||||||
|
|
||||||
// Columns
|
// Columns
|
||||||
columns: {
|
columns: {
|
||||||
@@ -3383,6 +3408,8 @@ export default {
|
|||||||
passthroughBody: 'Passthrough upstream error message',
|
passthroughBody: 'Passthrough upstream error message',
|
||||||
customMessage: 'Custom error message',
|
customMessage: 'Custom error message',
|
||||||
customMessagePlaceholder: 'Error message to return to client...',
|
customMessagePlaceholder: 'Error message to return to client...',
|
||||||
|
skipMonitoring: 'Skip monitoring',
|
||||||
|
skipMonitoringHint: 'When enabled, errors matching this rule will not be recorded in ops monitoring',
|
||||||
enabled: 'Enable this rule'
|
enabled: 'Enable this rule'
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@@ -865,8 +865,8 @@ export default {
|
|||||||
editUser: '编辑用户',
|
editUser: '编辑用户',
|
||||||
deleteUser: '删除用户',
|
deleteUser: '删除用户',
|
||||||
deleteConfirmMessage: "确定要删除用户 '{email}' 吗?此操作无法撤销。",
|
deleteConfirmMessage: "确定要删除用户 '{email}' 吗?此操作无法撤销。",
|
||||||
searchPlaceholder: '搜索用户邮箱或用户名、备注、支持模糊查询...',
|
searchPlaceholder: '邮箱/用户名/备注/API Key 模糊搜索...',
|
||||||
searchUsers: '搜索用户邮箱或用户名、备注、支持模糊查询',
|
searchUsers: '邮箱/用户名/备注/API Key 模糊搜索',
|
||||||
roleFilter: '角色筛选',
|
roleFilter: '角色筛选',
|
||||||
allRoles: '全部角色',
|
allRoles: '全部角色',
|
||||||
allStatus: '全部状态',
|
allStatus: '全部状态',
|
||||||
@@ -1397,9 +1397,22 @@ export default {
|
|||||||
syncResult: '同步结果',
|
syncResult: '同步结果',
|
||||||
syncResultSummary: '创建 {created},更新 {updated},跳过 {skipped},失败 {failed}',
|
syncResultSummary: '创建 {created},更新 {updated},跳过 {skipped},失败 {failed}',
|
||||||
syncErrors: '错误/跳过详情',
|
syncErrors: '错误/跳过详情',
|
||||||
syncCompleted: '同步完成:创建 {created},更新 {updated}',
|
syncCompleted: '同步完成:创建 {created},更新 {updated},跳过 {skipped}',
|
||||||
syncCompletedWithErrors: '同步完成但有错误:失败 {failed}(创建 {created},更新 {updated})',
|
syncCompletedWithErrors: '同步完成但有错误:失败 {failed}(创建 {created},更新 {updated},跳过 {skipped})',
|
||||||
syncFailed: '同步失败',
|
syncFailed: '同步失败',
|
||||||
|
crsPreview: '预览',
|
||||||
|
crsPreviewing: '预览中...',
|
||||||
|
crsPreviewFailed: '预览失败',
|
||||||
|
crsExistingAccounts: '将自动更新的已有账号',
|
||||||
|
crsNewAccounts: '新账号(可选择)',
|
||||||
|
crsSelectAll: '全选',
|
||||||
|
crsSelectNone: '全不选',
|
||||||
|
crsNoNewAccounts: '所有 CRS 账号均已同步。',
|
||||||
|
crsWillUpdate: '将更新 {count} 个已有账号。',
|
||||||
|
crsSelectedCount: '已选择 {count} 个新账号',
|
||||||
|
crsUpdateBehaviorNote:
|
||||||
|
'已有账号仅同步 CRS 返回的字段,缺失字段保持原值;凭据按键合并,不会清空未下发的键;未勾选"同步代理"时保留原有代理。',
|
||||||
|
crsBack: '返回',
|
||||||
editAccount: '编辑账号',
|
editAccount: '编辑账号',
|
||||||
deleteAccount: '删除账号',
|
deleteAccount: '删除账号',
|
||||||
deleteConfirmMessage: "确定要删除账号 '{name}' 吗?",
|
deleteConfirmMessage: "确定要删除账号 '{name}' 吗?",
|
||||||
@@ -1413,6 +1426,7 @@ export default {
|
|||||||
allPlatforms: '全部平台',
|
allPlatforms: '全部平台',
|
||||||
allTypes: '全部类型',
|
allTypes: '全部类型',
|
||||||
allStatus: '全部状态',
|
allStatus: '全部状态',
|
||||||
|
allGroups: '全部分组',
|
||||||
oauthType: 'OAuth',
|
oauthType: 'OAuth',
|
||||||
// Schedulable toggle
|
// Schedulable toggle
|
||||||
schedulable: '参与调度',
|
schedulable: '参与调度',
|
||||||
@@ -1420,7 +1434,7 @@ export default {
|
|||||||
schedulableEnabled: '调度已开启',
|
schedulableEnabled: '调度已开启',
|
||||||
schedulableDisabled: '调度已关闭',
|
schedulableDisabled: '调度已关闭',
|
||||||
failedToToggleSchedulable: '切换调度状态失败',
|
failedToToggleSchedulable: '切换调度状态失败',
|
||||||
allGroups: '共 {count} 个分组',
|
groupCountTotal: '共 {count} 个分组',
|
||||||
columns: {
|
columns: {
|
||||||
name: '名称',
|
name: '名称',
|
||||||
platformType: '平台/类型',
|
platformType: '平台/类型',
|
||||||
@@ -1480,6 +1494,7 @@ export default {
|
|||||||
googleOauth: 'Google OAuth',
|
googleOauth: 'Google OAuth',
|
||||||
codeAssist: 'Code Assist',
|
codeAssist: 'Code Assist',
|
||||||
antigravityOauth: 'Antigravity OAuth',
|
antigravityOauth: 'Antigravity OAuth',
|
||||||
|
antigravityApikey: '通过 Base URL + API Key 连接',
|
||||||
upstream: '对接上游',
|
upstream: '对接上游',
|
||||||
upstreamDesc: '通过 Base URL + API Key 连接上游',
|
upstreamDesc: '通过 Base URL + API Key 连接上游',
|
||||||
api_key: 'API Key',
|
api_key: 'API Key',
|
||||||
@@ -1758,7 +1773,7 @@ export default {
|
|||||||
// Upstream type
|
// Upstream type
|
||||||
upstream: {
|
upstream: {
|
||||||
baseUrl: '上游 Base URL',
|
baseUrl: '上游 Base URL',
|
||||||
baseUrlHint: '上游 Antigravity 服务的地址,例如:https://s.konstants.xyz',
|
baseUrlHint: '上游 Antigravity 服务的地址,例如:https://cloudcode-pa.googleapis.com',
|
||||||
apiKey: '上游 API Key',
|
apiKey: '上游 API Key',
|
||||||
apiKeyHint: '上游服务的 API Key',
|
apiKeyHint: '上游服务的 API Key',
|
||||||
pleaseEnterBaseUrl: '请输入上游 Base URL',
|
pleaseEnterBaseUrl: '请输入上游 Base URL',
|
||||||
@@ -1899,7 +1914,15 @@ export default {
|
|||||||
authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别',
|
authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别',
|
||||||
failedToGenerateUrl: '生成 Antigravity 授权链接失败',
|
failedToGenerateUrl: '生成 Antigravity 授权链接失败',
|
||||||
missingExchangeParams: '缺少 code / session_id / state',
|
missingExchangeParams: '缺少 code / session_id / state',
|
||||||
failedToExchangeCode: 'Antigravity 授权码兑换失败'
|
failedToExchangeCode: 'Antigravity 授权码兑换失败',
|
||||||
|
// Refresh Token auth
|
||||||
|
refreshTokenAuth: '手动输入 RT',
|
||||||
|
refreshTokenDesc: '输入您已有的 Antigravity Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。',
|
||||||
|
refreshTokenPlaceholder: '粘贴您的 Antigravity Refresh Token...\n支持多个,每行一个',
|
||||||
|
validating: '验证中...',
|
||||||
|
validateAndCreate: '验证并创建账号',
|
||||||
|
pleaseEnterRefreshToken: '请输入 Refresh Token',
|
||||||
|
failedToValidateRT: '验证 Refresh Token 失败'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// Gemini specific (platform-wide)
|
// Gemini specific (platform-wide)
|
||||||
@@ -2278,7 +2301,7 @@ export default {
|
|||||||
allStatus: '全部状态',
|
allStatus: '全部状态',
|
||||||
unused: '未使用',
|
unused: '未使用',
|
||||||
used: '已使用',
|
used: '已使用',
|
||||||
searchCodes: '搜索兑换码...',
|
searchCodes: '搜索兑换码或邮箱...',
|
||||||
exportCsv: '导出 CSV',
|
exportCsv: '导出 CSV',
|
||||||
deleteAllUnused: '删除全部未使用',
|
deleteAllUnused: '删除全部未使用',
|
||||||
deleteCodeConfirm: '确定要删除此兑换码吗?此操作无法撤销。',
|
deleteCodeConfirm: '确定要删除此兑换码吗?此操作无法撤销。',
|
||||||
@@ -2504,6 +2527,8 @@ export default {
|
|||||||
inputTokens: '输入 Token',
|
inputTokens: '输入 Token',
|
||||||
outputTokens: '输出 Token',
|
outputTokens: '输出 Token',
|
||||||
cacheCreationTokens: '缓存创建 Token',
|
cacheCreationTokens: '缓存创建 Token',
|
||||||
|
cacheCreation5mTokens: '缓存创建',
|
||||||
|
cacheCreation1hTokens: '缓存创建',
|
||||||
cacheReadTokens: '缓存读取 Token',
|
cacheReadTokens: '缓存读取 Token',
|
||||||
failedToLoad: '加载使用记录失败',
|
failedToLoad: '加载使用记录失败',
|
||||||
billingType: '计费类型',
|
billingType: '计费类型',
|
||||||
@@ -3513,6 +3538,7 @@ export default {
|
|||||||
custom: '自定义',
|
custom: '自定义',
|
||||||
code: '状态码',
|
code: '状态码',
|
||||||
body: '消息体',
|
body: '消息体',
|
||||||
|
skipMonitoring: '跳过监控',
|
||||||
|
|
||||||
// Columns
|
// Columns
|
||||||
columns: {
|
columns: {
|
||||||
@@ -3557,6 +3583,8 @@ export default {
|
|||||||
passthroughBody: '透传上游错误信息',
|
passthroughBody: '透传上游错误信息',
|
||||||
customMessage: '自定义错误信息',
|
customMessage: '自定义错误信息',
|
||||||
customMessagePlaceholder: '返回给客户端的错误信息...',
|
customMessagePlaceholder: '返回给客户端的错误信息...',
|
||||||
|
skipMonitoring: '跳过运维监控记录',
|
||||||
|
skipMonitoringHint: '开启后,匹配此规则的错误不会被记录到运维监控中',
|
||||||
enabled: '启用此规则'
|
enabled: '启用此规则'
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@@ -243,7 +243,7 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
.stat-value {
|
.stat-value {
|
||||||
@apply text-2xl font-bold text-gray-900 dark:text-white;
|
@apply text-2xl font-bold text-gray-900 dark:text-white truncate;
|
||||||
}
|
}
|
||||||
|
|
||||||
.stat-label {
|
.stat-label {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
*/
|
*/
|
||||||
type SanitizeOptions = {
|
type SanitizeOptions = {
|
||||||
allowRelative?: boolean
|
allowRelative?: boolean
|
||||||
|
allowDataUrl?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
export function sanitizeUrl(value: string, options: SanitizeOptions = {}): string {
|
export function sanitizeUrl(value: string, options: SanitizeOptions = {}): string {
|
||||||
@@ -18,6 +19,11 @@ export function sanitizeUrl(value: string, options: SanitizeOptions = {}): strin
|
|||||||
return trimmed
|
return trimmed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 允许 data:image/ 开头的 data URL(仅限图片类型)
|
||||||
|
if (options.allowDataUrl && trimmed.startsWith('data:image/')) {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
|
||||||
// 只接受绝对 URL,不使用 base URL 来避免相对路径被解析为当前域名
|
// 只接受绝对 URL,不使用 base URL 来避免相对路径被解析为当前域名
|
||||||
// 检查是否以 http:// 或 https:// 开头
|
// 检查是否以 http:// 或 https:// 开头
|
||||||
if (!trimmed.match(/^https?:\/\//i)) {
|
if (!trimmed.match(/^https?:\/\//i)) {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user