Compare commits

...

16 Commits

Author SHA1 Message Date
Wesley Liddick
723102766b Merge pull request #553 from Edric-Li/feat/antigravity-onboard-projectid
feat(antigravity): 添加 onboardUser 支持,修复 project_id 缺失问题
2026-02-11 13:52:44 +08:00
Edric Li
a4a46a8618 feat(antigravity): 添加 onboardUser 支持并修复 project_id 补齐逻辑
- 新增 OnboardUser API 客户端方法,支持账号 onboarding 获取 project_id
- loadProjectIDWithRetry 增加 onboard 回退:LoadCodeAssist 未返回 project_id 时自动触发 onboarding
- GetAccessToken 中 project_id 补齐改用轻量 FillProjectID 替代全量 RefreshAccountToken
- 补齐逻辑增加 5 分钟冷却机制,防止频繁重试
- OnboardUser 轮询等待改为 context 感知,支持提前取消
- 提取 mergeCredentials 辅助方法消除重复代码
- 新增 extractProjectIDFromOnboardResponse 和 resolveDefaultTierID 单元测试
2026-02-11 13:41:55 +08:00
Wesley Liddick
ae6fed15cc Merge pull request #548 from Edric-Li/main
feat: 错误处理增强、重试优化与性能改进
2026-02-10 22:46:58 +08:00
Edric Li
378e476e48 fix: 修复 CI 检查失败
- gofmt: 修复 error_passthrough_service.go 格式问题
- errcheck: 修复 error_passthrough_runtime_test.go 类型断言未检查
- staticcheck: if-else 改为 switch (gateway_service.go)
- test: 修复两个测试用例错误使用 MODEL_CAPACITY_EXHAUSTED 导致走错路径
2026-02-10 22:08:49 +08:00
Edric Li
2a1067c82b Merge remote-tracking branch 'upstream/main' 2026-02-10 21:52:33 +08:00
Edric Li
a54b81cf74 perf: 错误处理性能优化
- MatchRule 延迟/限制 body ToLower,先用 statusCode 短路,只在需要关键词匹配时转换且限制 8KB
- 预计算规则的小写关键词/平台和 error code set,消除运行时重复 ToLower 和线性扫描
- MODEL_CAPACITY_EXHAUSTED 全局去重,避免并发请求重复重试同一模型
- 503 重试 body 读取限制从 2MB 降至 8KB
- time.After 替换为 time.NewTimer,防止 context 取消时 timer 泄漏
2026-02-10 21:40:31 +08:00
Edric Li
2d4236f76e fix: 修复错误透传规则 skip_monitoring 未生效的问题
- ops_error_logger: status < 400 分支增加 OpsSkipPassthroughKey 检查
- ops_upstream_context: 新增 checkSkipMonitoringForUpstreamEvent,中间重试/故障转移事件也能触发跳过标记
- gateway_handler/openai_gateway_handler/gemini_v1beta_handler: handleFailoverExhausted 匹配规则后设置 OpsSkipPassthroughKey
- antigravity_gateway_service: writeMappedClaudeError 增加 applyErrorPassthroughRule 调用
2026-02-10 20:56:01 +08:00
Wesley Liddick
84ced1c497 Merge pull request #543 from slovx2/upstream_main
feat(antigravity): 转发与测试支持 daily/prod 单 URL 切换
2026-02-10 14:57:46 +08:00
song
b161312183 test(antigravity): 更新单URL策略下的重试断言 2026-02-10 14:36:09 +08:00
song
1f647b120a feat(antigravity): 转发与测试支持daily/prod单URL切换 2026-02-10 13:51:29 +08:00
Edric Li
7d0a30fa8f merge: sync upstream main (antigravity single-account 503 retry)
合并上游新增的 Antigravity 单账号 503 退避重试机制,
解决与本地 MODEL_CAPACITY_EXHAUSTED 逻辑的冲突,两者共存。
2026-02-10 12:00:21 +08:00
Edric Li
d95e04fd1f feat: 错误透传规则支持 skip_monitoring 跳过运维监控记录
在每条错误透传规则上新增 skip_monitoring 选项,开启后匹配该规则的错误
不会被记录到 ops_error_logs,减少监控噪音。默认关闭,不影响现有规则。
2026-02-10 11:42:39 +08:00
Edric Li
6114f69cca feat: MODEL_CAPACITY_EXHAUSTED 使用固定1s间隔重试60次,不切换账号
MODEL_CAPACITY_EXHAUSTED (503) 表示模型容量不足,所有账号共享同一容量池,
切换账号无意义。改为固定1s间隔重试最多60次,重试耗尽后直接返回上游错误。

- 新增 antigravityModelCapacityRetryMaxAttempts=60 和 antigravityModelCapacityRetryWait=1s
- shouldTriggerAntigravitySmartRetry 新增 isModelCapacityExhausted 返回值
- handleSmartRetry 对 MODEL_CAPACITY_EXHAUSTED 使用独立重试策略
- handleModelRateLimit 对 MODEL_CAPACITY_EXHAUSTED 仅标记 Handled,不设限流
- 重试耗尽后不设置模型限流、不清除粘性会话、不切换账号
2026-02-10 02:03:06 +08:00
Edric Li
d6c2921f2b feat: same-account retry before failover for transient errors
For retryable transient errors (Google 400 "invalid project resource name"
and empty stream responses), retry on the same account up to 2 times
(with 500ms delay) before switching to another account.

- Add RetryableOnSameAccount field to UpstreamFailoverError
- Add same-account retry loop in both Gemini and Claude/OpenAI handler paths
- Move temp-unschedule from service layer to handler layer (only after
  all same-account retries exhausted)
- Reduce temp-unschedule cooldown from 30 minutes to 1 minute
2026-02-10 00:53:54 +08:00
Edric Li
61c73287dc feat: failover and temp-unschedule on empty stream response
- Empty stream responses now return UpstreamFailoverError instead of
  plain 502, triggering automatic account switching (up to 10 retries)
- Add tempUnscheduleEmptyResponse: accounts returning empty responses
  are temp-unscheduled for 30 minutes
- Apply to both Claude and Gemini non-streaming paths
- Align googleConfigErrorCooldown from 60m to 30m for consistency
2026-02-09 23:25:30 +08:00
Edric Li
89905ec43d feat: failover and temp-unschedule on Google "Invalid project resource name" 400
Google 后端间歇性返回 400 "Invalid project resource name" 错误,
此前该错误直接透传给客户端且不触发账号切换,导致请求失败。

- 在 Antigravity 和 Gemini 两个平台的所有转发路径中,
  精确匹配该错误消息后触发 failover 自动换号重试
- 命中后将账号临时封禁 1 小时,避免反复调度到同一故障账号
- 提取共享函数 isGoogleProjectConfigError / tempUnscheduleGoogleConfigError
  消除跨 Service 的代码重复
2026-02-09 22:48:32 +08:00
37 changed files with 1482 additions and 201 deletions

View File

@@ -44,6 +44,8 @@ type ErrorPassthroughRule struct {
PassthroughBody bool `json:"passthrough_body,omitempty"` PassthroughBody bool `json:"passthrough_body,omitempty"`
// CustomMessage holds the value of the "custom_message" field. // CustomMessage holds the value of the "custom_message" field.
CustomMessage *string `json:"custom_message,omitempty"` CustomMessage *string `json:"custom_message,omitempty"`
// SkipMonitoring holds the value of the "skip_monitoring" field.
SkipMonitoring bool `json:"skip_monitoring,omitempty"`
// Description holds the value of the "description" field. // Description holds the value of the "description" field.
Description *string `json:"description,omitempty"` Description *string `json:"description,omitempty"`
selectValues sql.SelectValues selectValues sql.SelectValues
@@ -56,7 +58,7 @@ func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) {
switch columns[i] { switch columns[i] {
case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms: case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms:
values[i] = new([]byte) values[i] = new([]byte)
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody: case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody, errorpassthroughrule.FieldSkipMonitoring:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode: case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
@@ -171,6 +173,12 @@ func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) err
_m.CustomMessage = new(string) _m.CustomMessage = new(string)
*_m.CustomMessage = value.String *_m.CustomMessage = value.String
} }
case errorpassthroughrule.FieldSkipMonitoring:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field skip_monitoring", values[i])
} else if value.Valid {
_m.SkipMonitoring = value.Bool
}
case errorpassthroughrule.FieldDescription: case errorpassthroughrule.FieldDescription:
if value, ok := values[i].(*sql.NullString); !ok { if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field description", values[i]) return fmt.Errorf("unexpected type %T for field description", values[i])
@@ -257,6 +265,9 @@ func (_m *ErrorPassthroughRule) String() string {
builder.WriteString(*v) builder.WriteString(*v)
} }
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("skip_monitoring=")
builder.WriteString(fmt.Sprintf("%v", _m.SkipMonitoring))
builder.WriteString(", ")
if v := _m.Description; v != nil { if v := _m.Description; v != nil {
builder.WriteString("description=") builder.WriteString("description=")
builder.WriteString(*v) builder.WriteString(*v)

View File

@@ -39,6 +39,8 @@ const (
FieldPassthroughBody = "passthrough_body" FieldPassthroughBody = "passthrough_body"
// FieldCustomMessage holds the string denoting the custom_message field in the database. // FieldCustomMessage holds the string denoting the custom_message field in the database.
FieldCustomMessage = "custom_message" FieldCustomMessage = "custom_message"
// FieldSkipMonitoring holds the string denoting the skip_monitoring field in the database.
FieldSkipMonitoring = "skip_monitoring"
// FieldDescription holds the string denoting the description field in the database. // FieldDescription holds the string denoting the description field in the database.
FieldDescription = "description" FieldDescription = "description"
// Table holds the table name of the errorpassthroughrule in the database. // Table holds the table name of the errorpassthroughrule in the database.
@@ -61,6 +63,7 @@ var Columns = []string{
FieldResponseCode, FieldResponseCode,
FieldPassthroughBody, FieldPassthroughBody,
FieldCustomMessage, FieldCustomMessage,
FieldSkipMonitoring,
FieldDescription, FieldDescription,
} }
@@ -95,6 +98,8 @@ var (
DefaultPassthroughCode bool DefaultPassthroughCode bool
// DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field. // DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field.
DefaultPassthroughBody bool DefaultPassthroughBody bool
// DefaultSkipMonitoring holds the default value on creation for the "skip_monitoring" field.
DefaultSkipMonitoring bool
) )
// OrderOption defines the ordering options for the ErrorPassthroughRule queries. // OrderOption defines the ordering options for the ErrorPassthroughRule queries.
@@ -155,6 +160,11 @@ func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCustomMessage, opts...).ToFunc() return sql.OrderByField(FieldCustomMessage, opts...).ToFunc()
} }
// BySkipMonitoring orders the results by the skip_monitoring field.
func BySkipMonitoring(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSkipMonitoring, opts...).ToFunc()
}
// ByDescription orders the results by the description field. // ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption { func ByDescription(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDescription, opts...).ToFunc() return sql.OrderByField(FieldDescription, opts...).ToFunc()

View File

@@ -104,6 +104,11 @@ func CustomMessage(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
} }
// SkipMonitoring applies equality check predicate on the "skip_monitoring" field. It's identical to SkipMonitoringEQ.
func SkipMonitoring(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
}
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. // Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.ErrorPassthroughRule { func Description(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
@@ -544,6 +549,16 @@ func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v)) return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v))
} }
// SkipMonitoringEQ applies the EQ predicate on the "skip_monitoring" field.
func SkipMonitoringEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
}
// SkipMonitoringNEQ applies the NEQ predicate on the "skip_monitoring" field.
func SkipMonitoringNEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldSkipMonitoring, v))
}
// DescriptionEQ applies the EQ predicate on the "description" field. // DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.ErrorPassthroughRule { func DescriptionEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))

View File

@@ -172,6 +172,20 @@ func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *Error
return _c return _c
} }
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_c *ErrorPassthroughRuleCreate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleCreate {
_c.mutation.SetSkipMonitoring(v)
return _c
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_c *ErrorPassthroughRuleCreate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleCreate {
if v != nil {
_c.SetSkipMonitoring(*v)
}
return _c
}
// SetDescription sets the "description" field. // SetDescription sets the "description" field.
func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate { func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate {
_c.mutation.SetDescription(v) _c.mutation.SetDescription(v)
@@ -249,6 +263,10 @@ func (_c *ErrorPassthroughRuleCreate) defaults() {
v := errorpassthroughrule.DefaultPassthroughBody v := errorpassthroughrule.DefaultPassthroughBody
_c.mutation.SetPassthroughBody(v) _c.mutation.SetPassthroughBody(v)
} }
if _, ok := _c.mutation.SkipMonitoring(); !ok {
v := errorpassthroughrule.DefaultSkipMonitoring
_c.mutation.SetSkipMonitoring(v)
}
} }
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
@@ -287,6 +305,9 @@ func (_c *ErrorPassthroughRuleCreate) check() error {
if _, ok := _c.mutation.PassthroughBody(); !ok { if _, ok := _c.mutation.PassthroughBody(); !ok {
return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)} return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)}
} }
if _, ok := _c.mutation.SkipMonitoring(); !ok {
return &ValidationError{Name: "skip_monitoring", err: errors.New(`ent: missing required field "ErrorPassthroughRule.skip_monitoring"`)}
}
return nil return nil
} }
@@ -366,6 +387,10 @@ func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlg
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
_node.CustomMessage = &value _node.CustomMessage = &value
} }
if value, ok := _c.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
_node.SkipMonitoring = value
}
if value, ok := _c.mutation.Description(); ok { if value, ok := _c.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
_node.Description = &value _node.Description = &value
@@ -608,6 +633,18 @@ func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleU
return u return u
} }
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsert) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsert {
u.Set(errorpassthroughrule.FieldSkipMonitoring, v)
return u
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsert) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsert {
u.SetExcluded(errorpassthroughrule.FieldSkipMonitoring)
return u
}
// SetDescription sets the "description" field. // SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert { func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert {
u.Set(errorpassthroughrule.FieldDescription, v) u.Set(errorpassthroughrule.FieldDescription, v)
@@ -888,6 +925,20 @@ func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRu
}) })
} }
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsertOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.SetSkipMonitoring(v)
})
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsertOne) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.UpdateSkipMonitoring()
})
}
// SetDescription sets the "description" field. // SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne { func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) { return u.Update(func(s *ErrorPassthroughRuleUpsert) {
@@ -1337,6 +1388,20 @@ func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughR
}) })
} }
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsertBulk) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.SetSkipMonitoring(v)
})
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsertBulk) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.UpdateSkipMonitoring()
})
}
// SetDescription sets the "description" field. // SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk { func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) { return u.Update(func(s *ErrorPassthroughRuleUpsert) {

View File

@@ -227,6 +227,20 @@ func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRule
return _u return _u
} }
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_u *ErrorPassthroughRuleUpdate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdate {
_u.mutation.SetSkipMonitoring(v)
return _u
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetSkipMonitoring(*v)
}
return _u
}
// SetDescription sets the "description" field. // SetDescription sets the "description" field.
func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate { func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetDescription(v) _u.mutation.SetDescription(v)
@@ -387,6 +401,9 @@ func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, e
if _u.mutation.CustomMessageCleared() { if _u.mutation.CustomMessageCleared() {
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
} }
if value, ok := _u.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
}
if value, ok := _u.mutation.Description(); ok { if value, ok := _u.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
} }
@@ -611,6 +628,20 @@ func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughR
return _u return _u
} }
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetSkipMonitoring(v)
return _u
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetSkipMonitoring(*v)
}
return _u
}
// SetDescription sets the "description" field. // SetDescription sets the "description" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne { func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetDescription(v) _u.mutation.SetDescription(v)
@@ -801,6 +832,9 @@ func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *Er
if _u.mutation.CustomMessageCleared() { if _u.mutation.CustomMessageCleared() {
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
} }
if value, ok := _u.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
}
if value, ok := _u.mutation.Description(); ok { if value, ok := _u.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
} }

View File

@@ -325,6 +325,7 @@ var (
{Name: "response_code", Type: field.TypeInt, Nullable: true}, {Name: "response_code", Type: field.TypeInt, Nullable: true},
{Name: "passthrough_body", Type: field.TypeBool, Default: true}, {Name: "passthrough_body", Type: field.TypeBool, Default: true},
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647}, {Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
{Name: "skip_monitoring", Type: field.TypeBool, Default: false},
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647}, {Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
} }
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table. // ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.

View File

@@ -5776,6 +5776,7 @@ type ErrorPassthroughRuleMutation struct {
addresponse_code *int addresponse_code *int
passthrough_body *bool passthrough_body *bool
custom_message *string custom_message *string
skip_monitoring *bool
description *string description *string
clearedFields map[string]struct{} clearedFields map[string]struct{}
done bool done bool
@@ -6503,6 +6504,42 @@ func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() {
delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage)
} }
// SetSkipMonitoring sets the "skip_monitoring" field.
func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) {
m.skip_monitoring = &b
}
// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation.
func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) {
v := m.skip_monitoring
if v == nil {
return
}
return *v, true
}
// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity.
// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSkipMonitoring requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err)
}
return oldValue.SkipMonitoring, nil
}
// ResetSkipMonitoring resets all changes to the "skip_monitoring" field.
func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() {
m.skip_monitoring = nil
}
// SetDescription sets the "description" field. // SetDescription sets the "description" field.
func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { func (m *ErrorPassthroughRuleMutation) SetDescription(s string) {
m.description = &s m.description = &s
@@ -6586,7 +6623,7 @@ func (m *ErrorPassthroughRuleMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *ErrorPassthroughRuleMutation) Fields() []string { func (m *ErrorPassthroughRuleMutation) Fields() []string {
fields := make([]string, 0, 14) fields := make([]string, 0, 15)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, errorpassthroughrule.FieldCreatedAt) fields = append(fields, errorpassthroughrule.FieldCreatedAt)
} }
@@ -6626,6 +6663,9 @@ func (m *ErrorPassthroughRuleMutation) Fields() []string {
if m.custom_message != nil { if m.custom_message != nil {
fields = append(fields, errorpassthroughrule.FieldCustomMessage) fields = append(fields, errorpassthroughrule.FieldCustomMessage)
} }
if m.skip_monitoring != nil {
fields = append(fields, errorpassthroughrule.FieldSkipMonitoring)
}
if m.description != nil { if m.description != nil {
fields = append(fields, errorpassthroughrule.FieldDescription) fields = append(fields, errorpassthroughrule.FieldDescription)
} }
@@ -6663,6 +6703,8 @@ func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) {
return m.PassthroughBody() return m.PassthroughBody()
case errorpassthroughrule.FieldCustomMessage: case errorpassthroughrule.FieldCustomMessage:
return m.CustomMessage() return m.CustomMessage()
case errorpassthroughrule.FieldSkipMonitoring:
return m.SkipMonitoring()
case errorpassthroughrule.FieldDescription: case errorpassthroughrule.FieldDescription:
return m.Description() return m.Description()
} }
@@ -6700,6 +6742,8 @@ func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string
return m.OldPassthroughBody(ctx) return m.OldPassthroughBody(ctx)
case errorpassthroughrule.FieldCustomMessage: case errorpassthroughrule.FieldCustomMessage:
return m.OldCustomMessage(ctx) return m.OldCustomMessage(ctx)
case errorpassthroughrule.FieldSkipMonitoring:
return m.OldSkipMonitoring(ctx)
case errorpassthroughrule.FieldDescription: case errorpassthroughrule.FieldDescription:
return m.OldDescription(ctx) return m.OldDescription(ctx)
} }
@@ -6802,6 +6846,13 @@ func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) er
} }
m.SetCustomMessage(v) m.SetCustomMessage(v)
return nil return nil
case errorpassthroughrule.FieldSkipMonitoring:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSkipMonitoring(v)
return nil
case errorpassthroughrule.FieldDescription: case errorpassthroughrule.FieldDescription:
v, ok := value.(string) v, ok := value.(string)
if !ok { if !ok {
@@ -6963,6 +7014,9 @@ func (m *ErrorPassthroughRuleMutation) ResetField(name string) error {
case errorpassthroughrule.FieldCustomMessage: case errorpassthroughrule.FieldCustomMessage:
m.ResetCustomMessage() m.ResetCustomMessage()
return nil return nil
case errorpassthroughrule.FieldSkipMonitoring:
m.ResetSkipMonitoring()
return nil
case errorpassthroughrule.FieldDescription: case errorpassthroughrule.FieldDescription:
m.ResetDescription() m.ResetDescription()
return nil return nil

View File

@@ -326,6 +326,10 @@ func init() {
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor() errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field. // errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool) errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
// errorpassthroughruleDescSkipMonitoring is the schema descriptor for skip_monitoring field.
errorpassthroughruleDescSkipMonitoring := errorpassthroughruleFields[11].Descriptor()
// errorpassthroughrule.DefaultSkipMonitoring holds the default value on creation for the skip_monitoring field.
errorpassthroughrule.DefaultSkipMonitoring = errorpassthroughruleDescSkipMonitoring.Default.(bool)
groupMixin := schema.Group{}.Mixin() groupMixin := schema.Group{}.Mixin()
groupMixinHooks1 := groupMixin[1].Hooks() groupMixinHooks1 := groupMixin[1].Hooks()
group.Hooks[0] = groupMixinHooks1[0] group.Hooks[0] = groupMixinHooks1[0]

View File

@@ -105,6 +105,12 @@ func (ErrorPassthroughRule) Fields() []ent.Field {
Optional(). Optional().
Nillable(), Nillable(),
// skip_monitoring: 是否跳过运维监控记录
// true: 匹配此规则的错误不会被记录到 ops_error_logs
// false: 正常记录到运维监控(默认行为)
field.Bool("skip_monitoring").
Default(false),
// description: 规则描述,用于说明规则的用途 // description: 规则描述,用于说明规则的用途
field.Text("description"). field.Text("description").
Optional(). Optional().

View File

@@ -32,6 +32,7 @@ type CreateErrorPassthroughRuleRequest struct {
ResponseCode *int `json:"response_code"` ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"` PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"` CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"` Description *string `json:"description"`
} }
@@ -48,6 +49,7 @@ type UpdateErrorPassthroughRuleRequest struct {
ResponseCode *int `json:"response_code"` ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"` PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"` CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"` Description *string `json:"description"`
} }
@@ -122,6 +124,9 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
} else { } else {
rule.PassthroughBody = true rule.PassthroughBody = true
} }
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
rule.ResponseCode = req.ResponseCode rule.ResponseCode = req.ResponseCode
rule.CustomMessage = req.CustomMessage rule.CustomMessage = req.CustomMessage
rule.Description = req.Description rule.Description = req.Description
@@ -190,6 +195,7 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
ResponseCode: existing.ResponseCode, ResponseCode: existing.ResponseCode,
PassthroughBody: existing.PassthroughBody, PassthroughBody: existing.PassthroughBody,
CustomMessage: existing.CustomMessage, CustomMessage: existing.CustomMessage,
SkipMonitoring: existing.SkipMonitoring,
Description: existing.Description, Description: existing.Description,
} }
@@ -230,6 +236,9 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
if req.Description != nil { if req.Description != nil {
rule.Description = req.Description rule.Description = req.Description
} }
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
// 确保切片不为 nil // 确保切片不为 nil
if rule.ErrorCodes == nil { if rule.ErrorCodes == nil {

View File

@@ -235,6 +235,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
@@ -359,11 +360,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) { if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true forceCacheBilling = true
} }
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return return
@@ -427,6 +445,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
@@ -579,11 +598,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr lastFailoverErr = failoverErr
if needForceCacheBilling(hasBoundSession, failoverErr) { if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true forceCacheBilling = true
} }
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches { if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return return
@@ -863,6 +899,23 @@ func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFa
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
} }
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
)
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
func sleepSameAccountRetryDelay(ctx context.Context) bool {
select {
case <-ctx.Done():
return false
case <-time.After(sameAccountRetryDelay):
return true
}
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s… // sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。 // 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool { func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
@@ -918,6 +971,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
msg = *rule.CustomMessage msg = *rule.CustomMessage
} }
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return return
} }

View File

@@ -554,6 +554,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
msg = *rule.CustomMessage msg = *rule.CustomMessage
} }
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
googleError(c, respCode, msg) googleError(c, respCode, msg)
return return
} }

View File

@@ -354,6 +354,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
msg = *rule.CustomMessage msg = *rule.CustomMessage
} }
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return return
} }

View File

@@ -537,6 +537,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
// Store request headers/body only when an upstream error occurred to keep overhead minimal. // Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
if skip, _ := v.(bool); skip {
return
}
}
enqueueOpsErrorLog(ops, entry, requestBody) enqueueOpsErrorLog(ops, entry, requestBody)
return return
} }
@@ -544,6 +551,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
body := w.buf.Bytes() body := w.buf.Bytes()
parsed := parseOpsErrorResponse(body) parsed := parseOpsErrorResponse(body)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
if skip, _ := v.(bool); skip {
return
}
}
// Skip logging if the error should be filtered based on settings // Skip logging if the error should be filtered based on settings
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) { if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
return return

View File

@@ -18,6 +18,7 @@ type ErrorPassthroughRule struct {
ResponseCode *int `json:"response_code"` // 自定义状态码passthrough_code=false 时使用) ResponseCode *int `json:"response_code"` // 自定义状态码passthrough_code=false 时使用)
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息 PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
CustomMessage *string `json:"custom_message"` // 自定义错误信息passthrough_body=false 时使用) CustomMessage *string `json:"custom_message"` // 自定义错误信息passthrough_body=false 时使用)
SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录
Description *string `json:"description"` // 规则描述 Description *string `json:"description"` // 规则描述
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`

View File

@@ -115,6 +115,23 @@ type LoadCodeAssistResponse struct {
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"` IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
} }
// OnboardUserRequest onboardUser 请求
type OnboardUserRequest struct {
TierID string `json:"tierId"`
Metadata struct {
IDEType string `json:"ideType"`
Platform string `json:"platform,omitempty"`
PluginType string `json:"pluginType,omitempty"`
} `json:"metadata"`
}
// OnboardUserResponse onboardUser 响应
type OnboardUserResponse struct {
Name string `json:"name,omitempty"`
Done bool `json:"done"`
Response map[string]any `json:"response,omitempty"`
}
// GetTier 获取账户类型 // GetTier 获取账户类型
// 优先返回 paidTier付费订阅级别否则返回 currentTier // 优先返回 paidTier付费订阅级别否则返回 currentTier
func (r *LoadCodeAssistResponse) GetTier() string { func (r *LoadCodeAssistResponse) GetTier() string {
@@ -361,6 +378,117 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, lastErr return nil, nil, lastErr
} }
// OnboardUser 触发账号 onboarding并返回 project_id
// 说明:
// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject
// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。
func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
tierID = strings.TrimSpace(tierID)
if tierID == "" {
return "", fmt.Errorf("tier_id 为空")
}
reqBody := OnboardUserRequest{TierID: tierID}
reqBody.Metadata.IDEType = "ANTIGRAVITY"
reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED"
reqBody.Metadata.PluginType = "GEMINI"
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
availableURLs := BaseURLs
var lastErr error
for urlIdx, baseURL := range availableURLs {
apiURL := baseURL + "/v1internal:onboardUser"
for attempt := 1; attempt <= 5; attempt++ {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes))
if err != nil {
lastErr = fmt.Errorf("创建请求失败: %w", err)
break
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("onboardUser 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
break
}
return "", lastErr
}
respBodyBytes, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
}
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
break
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
return "", lastErr
}
var onboardResp OnboardUserResponse
if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil {
lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err)
return "", lastErr
}
if onboardResp.Done {
if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" {
DefaultURLAvailability.MarkSuccess(baseURL)
return projectID, nil
}
lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id")
return "", lastErr
}
// done=false 时等待后重试(与 CLIProxyAPI 行为一致)
select {
case <-time.After(2 * time.Second):
case <-ctx.Done():
return "", ctx.Err()
}
}
}
if lastErr != nil {
return "", lastErr
}
return "", fmt.Errorf("onboardUser 未返回 project_id")
}
func extractProjectIDFromOnboardResponse(resp map[string]any) string {
if len(resp) == 0 {
return ""
}
if v, ok := resp["cloudaicompanionProject"]; ok {
switch project := v.(type) {
case string:
return strings.TrimSpace(project)
case map[string]any:
if id, ok := project["id"].(string); ok {
return strings.TrimSpace(id)
}
}
}
return ""
}
// ModelQuotaInfo 模型配额信息 // ModelQuotaInfo 模型配额信息
type ModelQuotaInfo struct { type ModelQuotaInfo struct {
RemainingFraction float64 `json:"remainingFraction"` RemainingFraction float64 `json:"remainingFraction"`

View File

@@ -0,0 +1,76 @@
package antigravity
import (
"testing"
)
func TestExtractProjectIDFromOnboardResponse(t *testing.T) {
t.Parallel()
tests := []struct {
name string
resp map[string]any
want string
}{
{
name: "nil response",
resp: nil,
want: "",
},
{
name: "empty response",
resp: map[string]any{},
want: "",
},
{
name: "project as string",
resp: map[string]any{
"cloudaicompanionProject": "my-project-123",
},
want: "my-project-123",
},
{
name: "project as string with spaces",
resp: map[string]any{
"cloudaicompanionProject": " my-project-123 ",
},
want: "my-project-123",
},
{
name: "project as map with id",
resp: map[string]any{
"cloudaicompanionProject": map[string]any{
"id": "proj-from-map",
},
},
want: "proj-from-map",
},
{
name: "project as map without id",
resp: map[string]any{
"cloudaicompanionProject": map[string]any{
"name": "some-name",
},
},
want: "",
},
{
name: "missing cloudaicompanionProject key",
resp: map[string]any{
"otherField": "value",
},
want: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := extractProjectIDFromOnboardResponse(tc.resp)
if got != tc.want {
t.Fatalf("extractProjectIDFromOnboardResponse() = %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -54,7 +54,8 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err
SetPriority(rule.Priority). SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode). SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode). SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody) SetPassthroughBody(rule.PassthroughBody).
SetSkipMonitoring(rule.SkipMonitoring)
if len(rule.ErrorCodes) > 0 { if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes) builder.SetErrorCodes(rule.ErrorCodes)
@@ -90,7 +91,8 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err
SetPriority(rule.Priority). SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode). SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode). SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody) SetPassthroughBody(rule.PassthroughBody).
SetSkipMonitoring(rule.SkipMonitoring)
// 处理可选字段 // 处理可选字段
if len(rule.ErrorCodes) > 0 { if len(rule.ErrorCodes) > 0 {
@@ -149,6 +151,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model
Platforms: e.Platforms, Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode, PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody, PassthroughBody: e.PassthroughBody,
SkipMonitoring: e.SkipMonitoring,
CreatedAt: e.CreatedAt, CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt, UpdatedAt: e.UpdatedAt,
} }

View File

@@ -16,6 +16,7 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -40,6 +41,12 @@ const (
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待) antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用) antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
// MODEL_CAPACITY_EXHAUSTED 专用重试参数
// 模型容量不足时,所有账号共享同一容量池,切换账号无意义
// 使用固定 1s 间隔重试,最多重试 60 次
antigravityModelCapacityRetryMaxAttempts = 60
antigravityModelCapacityRetryWait = 1 * time.Second
// Google RPC 状态和类型常量 // Google RPC 状态和类型常量
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED" googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
googleRPCStatusUnavailable = "UNAVAILABLE" googleRPCStatusUnavailable = "UNAVAILABLE"
@@ -60,6 +67,9 @@ const (
// 单账号 503 退避重试:原地重试的总累计等待时间上限 // 单账号 503 退避重试:原地重试的总累计等待时间上限
// 超过此上限将不再重试,直接返回 503 // 超过此上限将不再重试,直接返回 503
antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second
// MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间
antigravityModelCapacityCooldown = 10 * time.Second
) )
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
@@ -68,8 +78,15 @@ var antigravityPassthroughErrorMessages = []string{
"prompt is too long", "prompt is too long",
} }
// MODEL_CAPACITY_EXHAUSTED 全局去重:避免多个并发请求同时对同一模型进行容量耗尽重试
var (
modelCapacityExhaustedMu sync.RWMutex
modelCapacityExhaustedUntil = make(map[string]time.Time) // modelName -> cooldown until
)
const ( const (
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
) )
@@ -131,6 +148,20 @@ type antigravityRetryLoopResult struct {
resp *http.Response resp *http.Response
} }
// resolveAntigravityForwardBaseURL 解析转发用 base URL。
// 默认使用 dailyForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。
func resolveAntigravityForwardBaseURL() string {
baseURLs := antigravity.ForwardBaseURLs()
if len(baseURLs) == 0 {
return ""
}
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
if mode == "prod" && len(baseURLs) > 1 {
return baseURLs[1]
}
return baseURLs[0]
}
// smartRetryAction 智能重试的处理结果 // smartRetryAction 智能重试的处理结果
type smartRetryAction int type smartRetryAction int
@@ -158,7 +189,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
// 判断是否触发智能重试 // 判断是否触发智能重试
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody) shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
// 情况1: retryDelay >= 阈值,限流模型并切换账号 // 情况1: retryDelay >= 阈值,限流模型并切换账号
if shouldRateLimitModel { if shouldRateLimitModel {
@@ -195,20 +226,48 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
} }
// 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次) // 情况2: retryDelay < 阈值(或 MODEL_CAPACITY_EXHAUSTED智能重试
if shouldSmartRetry { if shouldSmartRetry {
var lastRetryResp *http.Response var lastRetryResp *http.Response
var lastRetryBody []byte var lastRetryBody []byte
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ { // MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数60 次,固定 1s 间隔)
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", maxAttempts := antigravitySmartRetryMaxAttempts
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID) if isModelCapacityExhausted {
maxAttempts = antigravityModelCapacityRetryMaxAttempts
waitDuration = antigravityModelCapacityRetryWait
// 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503
if modelName != "" {
modelCapacityExhaustedMu.RLock()
cooldownUntil, exists := modelCapacityExhaustedUntil[modelName]
modelCapacityExhaustedMu.RUnlock()
if exists && time.Now().Before(cooldownUntil) {
log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)",
p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05"))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
},
}
}
}
}
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID)
timer := time.NewTimer(waitDuration)
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
timer.Stop()
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix) log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
case <-time.After(waitDuration): case <-timer.C:
} }
// 智能重试:创建新请求 // 智能重试:创建新请求
@@ -228,13 +287,19 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts) log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
// 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown
if isModelCapacityExhausted && modelName != "" {
modelCapacityExhaustedMu.Lock()
delete(modelCapacityExhaustedUntil, modelName)
modelCapacityExhaustedMu.Unlock()
}
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp} return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
} }
// 网络错误时,继续重试 // 网络错误时,继续重试
if retryErr != nil || retryResp == nil { if retryErr != nil || retryResp == nil {
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr) log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, maxAttempts, retryErr)
continue continue
} }
@@ -244,13 +309,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
lastRetryResp = retryResp lastRetryResp = retryResp
if retryResp != nil { if retryResp != nil {
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
_ = retryResp.Body.Close() _ = retryResp.Body.Close()
} }
// 解析新的重试信息,用于下次重试的等待时间 // 解析新的重试信息,用于下次重试的等待时间MODEL_CAPACITY_EXHAUSTED 使用固定循环,跳过)
if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil { if !isModelCapacityExhausted && attempt < maxAttempts && lastRetryBody != nil {
newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) newShouldRetry, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
if newShouldRetry && newWaitDuration > 0 { if newShouldRetry && newWaitDuration > 0 {
waitDuration = newWaitDuration waitDuration = newWaitDuration
} }
@@ -267,6 +332,27 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
retryBody = respBody retryBody = respBody
} }
// MODEL_CAPACITY_EXHAUSTED模型容量不足切换账号无意义
// 直接返回上游错误响应,不设置模型限流,不切换账号
if isModelCapacityExhausted {
// 设置 cooldown让后续请求快速失败避免重复重试
if modelName != "" {
modelCapacityExhaustedMu.Lock()
modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown)
modelCapacityExhaustedMu.Unlock()
}
log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)",
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryBody)),
},
}
}
// 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号, // 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号,
// 直接返回 503 让 Handler 层的单账号退避循环做最终处理。 // 直接返回 503 让 Handler 层的单账号退避循环做最终处理。
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) { if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
@@ -283,7 +369,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
} }
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)", log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200)) p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
resetAt := time.Now().Add(rateLimitDuration) resetAt := time.Now().Add(rateLimitDuration)
if p.accountRepo != nil && modelName != "" { if p.accountRepo != nil && modelName != "" {
@@ -367,11 +453,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d", log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID) p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID)
timer := time.NewTimer(waitDuration)
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
timer.Stop()
log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix) log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix)
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
case <-time.After(waitDuration): case <-timer.C:
} }
totalWaited += waitDuration totalWaited += waitDuration
@@ -405,12 +493,12 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
_ = lastRetryResp.Body.Close() _ = lastRetryResp.Body.Close()
} }
lastRetryResp = retryResp lastRetryResp = retryResp
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
_ = retryResp.Body.Close() _ = retryResp.Body.Close()
// 解析新的重试信息,更新下次等待时间 // 解析新的重试信息,更新下次等待时间
if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil { if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil {
_, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) _, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
if newWaitDuration > 0 { if newWaitDuration > 0 {
waitDuration = newWaitDuration waitDuration = newWaitDuration
if waitDuration > antigravitySingleAccountSmartRetryMaxWait { if waitDuration > antigravitySingleAccountSmartRetryMaxWait {
@@ -466,10 +554,11 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
} }
} }
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() baseURL := resolveAntigravityForwardBaseURL()
if len(availableURLs) == 0 { if baseURL == "" {
availableURLs = antigravity.BaseURLs return nil, errors.New("no antigravity forward base url configured")
} }
availableURLs := []string{baseURL}
var resp *http.Response var resp *http.Response
var usedBaseURL string var usedBaseURL string
@@ -907,11 +996,11 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// URL fallback 循环 baseURL := resolveAntigravityForwardBaseURL()
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() if baseURL == "" {
if len(availableURLs) == 0 { return nil, errors.New("no antigravity forward base url configured")
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
} }
availableURLs := []string{baseURL}
var lastErr error var lastErr error
for urlIdx, baseURL := range availableURLs { for urlIdx, baseURL := range availableURLs {
@@ -1376,7 +1465,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
break break
} }
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
_ = retryResp.Body.Close() _ = retryResp.Body.Close()
if retryResp.StatusCode == http.StatusTooManyRequests { if retryResp.StatusCode == http.StatusTooManyRequests {
retryBaseURL := "" retryBaseURL := ""
@@ -1457,6 +1546,27 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
if resp.StatusCode == http.StatusBadRequest {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if isGoogleProjectConfigError(msg) {
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
upstreamDetail := s.getUpstreamErrorDetail(respBody)
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
}
}
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
@@ -1997,6 +2107,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// Always record upstream context for Ops error logs, even when we will failover. // Always record upstream context for Ops error logs, even when we will failover.
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
if resp.StatusCode == http.StatusBadRequest && isGoogleProjectConfigError(strings.ToLower(upstreamMsg)) {
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps, RetryableOnSameAccount: true}
}
if s.shouldFailoverUpstreamError(resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
@@ -2092,6 +2218,44 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
} }
} }
// isGoogleProjectConfigError 判断(已提取的小写)错误消息是否属于 Google 服务端配置类问题。
// 只精确匹配已知的服务端侧错误,避免对客户端请求错误做无意义重试。
// 适用于所有走 Google 后端的平台Antigravity、Gemini
func isGoogleProjectConfigError(lowerMsg string) bool {
// Google 间歇性 BugProject ID 有效但被临时识别失败
return strings.Contains(lowerMsg, "invalid project resource name")
}
// googleConfigErrorCooldown 服务端配置类 400 错误的临时封禁时长
const googleConfigErrorCooldown = 1 * time.Minute
// tempUnscheduleGoogleConfigError 对服务端配置类 400 错误触发临时封禁,
// 避免短时间内反复调度到同一个有问题的账号。
func tempUnscheduleGoogleConfigError(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
until := time.Now().Add(googleConfigErrorCooldown)
reason := "400: invalid project resource name (auto temp-unschedule 1m)"
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
} else {
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
}
}
// emptyResponseCooldown 空流式响应的临时封禁时长
const emptyResponseCooldown = 1 * time.Minute
// tempUnscheduleEmptyResponse 对空流式响应触发临时封禁,
// 避免短时间内反复调度到同一个返回空响应的账号。
func tempUnscheduleEmptyResponse(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
until := time.Now().Add(emptyResponseCooldown)
reason := "empty stream response (auto temp-unschedule 1m)"
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
} else {
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
}
}
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待 // sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
// 返回 true 表示正常完成等待false 表示 context 已取消 // 返回 true 表示正常完成等待false 表示 context 已取消
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
@@ -2108,10 +2272,12 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
sleepFor = 0 sleepFor = 0
} }
timer := time.NewTimer(sleepFor)
select { select {
case <-ctx.Done(): case <-ctx.Done():
timer.Stop()
return false return false
case <-time.After(sleepFor): case <-timer.C:
return true return true
} }
} }
@@ -2156,8 +2322,9 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
// antigravitySmartRetryInfo 智能重试所需的信息 // antigravitySmartRetryInfo 智能重试所需的信息
type antigravitySmartRetryInfo struct { type antigravitySmartRetryInfo struct {
RetryDelay time.Duration // 重试延迟时间 RetryDelay time.Duration // 重试延迟时间
ModelName string // 限流的模型名称(如 "claude-sonnet-4-5" ModelName string // 限流的模型名称(如 "claude-sonnet-4-5"
IsModelCapacityExhausted bool // 是否为模型容量不足MODEL_CAPACITY_EXHAUSTED
} }
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息 // parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
@@ -2272,31 +2439,40 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
} }
return &antigravitySmartRetryInfo{ return &antigravitySmartRetryInfo{
RetryDelay: retryDelay, RetryDelay: retryDelay,
ModelName: modelName, ModelName: modelName,
IsModelCapacityExhausted: hasModelCapacityExhausted,
} }
} }
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试 // shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
// 返回: // 返回:
// - shouldRetry: 是否应该智能重试retryDelay < antigravityRateLimitThreshold // - shouldRetry: 是否应该智能重试retryDelay < antigravityRateLimitThreshold,或 MODEL_CAPACITY_EXHAUSTED
// - shouldRateLimitModel: 是否应该限流模型retryDelay >= antigravityRateLimitThreshold // - shouldRateLimitModel: 是否应该限流模型并切换账号(仅 RATE_LIMIT_EXCEEDED 且 retryDelay >= 阈值
// - waitDuration: 等待时间智能重试时使用shouldRateLimitModel=true 时为 0 // - waitDuration: 等待时间
// - modelName: 限流的模型名称 // - modelName: 限流的模型名称
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) { // - isModelCapacityExhausted: 是否为模型容量不足MODEL_CAPACITY_EXHAUSTED
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string, isModelCapacityExhausted bool) {
if account.Platform != PlatformAntigravity { if account.Platform != PlatformAntigravity {
return false, false, 0, "" return false, false, 0, "", false
} }
info := parseAntigravitySmartRetryInfo(respBody) info := parseAntigravitySmartRetryInfo(respBody)
if info == nil { if info == nil {
return false, false, 0, "" return false, false, 0, "", false
} }
// MODEL_CAPACITY_EXHAUSTED模型容量不足所有账号共享同一模型容量池
// 切换账号无意义,使用固定 1s 间隔重试
if info.IsModelCapacityExhausted {
return true, false, antigravityModelCapacityRetryWait, info.ModelName, true
}
// RATE_LIMIT_EXCEEDED账号级限流
// retryDelay >= 阈值:直接限流模型,不重试 // retryDelay >= 阈值:直接限流模型,不重试
// 注意:如果上游未提供 retryDelayparseAntigravitySmartRetryInfo 已设置为默认 30s // 注意:如果上游未提供 retryDelayparseAntigravitySmartRetryInfo 已设置为默认 30s
if info.RetryDelay >= antigravityRateLimitThreshold { if info.RetryDelay >= antigravityRateLimitThreshold {
return false, true, info.RetryDelay, info.ModelName return false, true, info.RetryDelay, info.ModelName, false
} }
// retryDelay < 阈值:智能重试 // retryDelay < 阈值:智能重试
@@ -2305,7 +2481,7 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou
waitDuration = antigravitySmartRetryMinWait waitDuration = antigravitySmartRetryMinWait
} }
return true, false, waitDuration, info.ModelName return true, false, waitDuration, info.ModelName, false
} }
// handleModelRateLimitParams 模型级限流处理参数 // handleModelRateLimitParams 模型级限流处理参数
@@ -2331,8 +2507,9 @@ type handleModelRateLimitResult struct {
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用) // handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
// 仅处理 429/503解析模型名和 retryDelay // 仅处理 429/503解析模型名和 retryDelay
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true由调用方等待后重试 // - MODEL_CAPACITY_EXHAUSTED: 返回 Handled=true实际重试由 handleSmartRetry 处理)
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError // - RATE_LIMIT_EXCEEDED + retryDelay < 阈值: 返回 ShouldRetry=true由调用方等待后重试
// - RATE_LIMIT_EXCEEDED + retryDelay >= 阈值: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult { func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
if p.statusCode != 429 && p.statusCode != 503 { if p.statusCode != 429 && p.statusCode != 503 {
return &handleModelRateLimitResult{Handled: false} return &handleModelRateLimitResult{Handled: false}
@@ -2343,7 +2520,17 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
return &handleModelRateLimitResult{Handled: false} return &handleModelRateLimitResult{Handled: false}
} }
// < antigravityRateLimitThreshold: 等待后重试 // MODEL_CAPACITY_EXHAUSTED模型容量不足所有账号共享同一容量池
// 切换账号无意义,不设置模型限流(实际重试由 handleSmartRetry 处理)
if info.IsModelCapacityExhausted {
log.Printf("%s status=%d model_capacity_exhausted model=%s (not switching account, retry handled by smart retry)",
p.prefix, p.statusCode, info.ModelName)
return &handleModelRateLimitResult{
Handled: true,
}
}
// RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试
if info.RetryDelay < antigravityRateLimitThreshold { if info.RetryDelay < antigravityRateLimitThreshold {
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v", log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
p.prefix, p.statusCode, info.ModelName, info.RetryDelay) p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
@@ -2354,7 +2541,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
} }
} }
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号 // RATE_LIMIT_EXCEEDED: >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
s.setModelRateLimitAndClearSession(p, info) s.setModelRateLimitAndClearSession(p, info)
return &handleModelRateLimitResult{ return &handleModelRateLimitResult{
@@ -2902,9 +3089,14 @@ returnResponse:
// 选择最后一个有效响应 // 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts) finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况 // 处理空响应情况 — 触发同账号重试 + failover 切换账号
if last == nil && lastWithParts == nil { if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") log.Printf("[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover")
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
RetryableOnSameAccount: true,
}
} }
// 如果收集到了图片 parts需要合并到最终响应中 // 如果收集到了图片 parts需要合并到最终响应中
@@ -3122,6 +3314,21 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
} }
// 检查错误透传规则
if ptStatus, ptErrType, ptErrMsg, matched := applyErrorPassthroughRule(
c, account.Platform, upstreamStatus, body,
0, "", "",
); matched {
c.JSON(ptStatus, gin.H{
"type": "error",
"error": gin.H{"type": ptErrType, "message": ptErrMsg},
})
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
var statusCode int var statusCode int
var errType, errMsg string var errType, errMsg string
@@ -3317,10 +3524,14 @@ returnResponse:
// 选择最后一个有效响应 // 选择最后一个有效响应
finalResponse := pickGeminiCollectResult(last, lastWithParts) finalResponse := pickGeminiCollectResult(last, lastWithParts)
// 处理空响应情况 // 处理空响应情况 — 触发同账号重试 + failover 切换账号
if last == nil && lastWithParts == nil { if last == nil && lastWithParts == nil {
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") log.Printf("[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream") return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
RetryableOnSameAccount: true,
}
} }
// 将收集的所有 parts 合并到最终响应中 // 将收集的所有 parts 合并到最终响应中

View File

@@ -273,12 +273,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
} }
client := antigravity.NewClient(proxyURL) client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken) loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil return loadResp.CloudAICompanionProject, nil
} }
if err == nil {
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
return projectID, nil
} else if onboardErr != nil {
lastErr = onboardErr
continue
}
}
// 记录错误 // 记录错误
if err != nil { if err != nil {
lastErr = err lastErr = err
@@ -292,6 +301,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
} }
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
tierID := resolveDefaultTierID(loadRaw)
if tierID == "" {
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
}
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
if err != nil {
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
}
return projectID, nil
}
func resolveDefaultTierID(loadRaw map[string]any) string {
if len(loadRaw) == 0 {
return ""
}
rawTiers, ok := loadRaw["allowedTiers"]
if !ok {
return ""
}
tiers, ok := rawTiers.([]any)
if !ok {
return ""
}
for _, rawTier := range tiers {
tier, ok := rawTier.(map[string]any)
if !ok {
continue
}
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
continue
}
if id, ok := tier["id"].(string); ok {
id = strings.TrimSpace(id)
if id != "" {
return id
}
}
}
return ""
}
// FillProjectID 仅获取 project_id不刷新 OAuth token
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
}
// BuildAccountCredentials 构建账户凭证 // BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any { func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{ creds := map[string]any{

View File

@@ -0,0 +1,82 @@
package service
import (
"testing"
)
func TestResolveDefaultTierID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
loadRaw map[string]any
want string
}{
{
name: "nil loadRaw",
loadRaw: nil,
want: "",
},
{
name: "missing allowedTiers",
loadRaw: map[string]any{
"paidTier": map[string]any{"id": "g1-pro-tier"},
},
want: "",
},
{
name: "empty allowedTiers",
loadRaw: map[string]any{"allowedTiers": []any{}},
want: "",
},
{
name: "tier missing id field",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"isDefault": true},
},
},
want: "",
},
{
name: "allowedTiers but no default",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": false},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "",
},
{
name: "default tier found",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": true},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "free-tier",
},
{
name: "default tier id with spaces",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": " standard-tier ", "isDefault": true},
},
},
want: "standard-tier",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := resolveDefaultTierID(tc.loadRaw)
if got != tc.want {
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -86,7 +86,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
return nil return nil
} }
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability oldAvailability := antigravity.DefaultURLAvailability
defer func() { defer func() {
@@ -131,15 +133,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.resp) require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }() defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusOK, result.resp.StatusCode) require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.False(t, handleErrorCalled) require.True(t, handleErrorCalled)
require.Len(t, upstream.calls, 2) require.Len(t, upstream.calls, antigravityMaxRetries)
require.True(t, strings.HasPrefix(upstream.calls[0], base1)) for _, callURL := range upstream.calls {
require.True(t, strings.HasPrefix(upstream.calls[1], base2)) require.True(t, strings.HasPrefix(callURL, base1))
}
available := antigravity.DefaultURLAvailability.GetAvailableURLs() available := antigravity.DefaultURLAvailability.GetAvailableURLs()
require.NotEmpty(t, available) require.NotEmpty(t, available)
require.Equal(t, base2, available[0]) require.Equal(t, base1, available[0])
} }
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景 // TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
@@ -188,13 +191,14 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
} }
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景 // TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { // MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
repo := &stubAntigravityAccountRepo{} repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo} svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity} account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
// 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流 // 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号
body := []byte(`{ body := []byte(`{
"error": { "error": {
"status": "UNAVAILABLE", "status": "UNAVAILABLE",
@@ -207,13 +211,13 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 应该触发模型限流 // MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流
// 实际重试由 handleSmartRetry 处理
require.NotNil(t, result) require.NotNil(t, result)
require.True(t, result.Handled) require.True(t, result.Handled)
require.NotNil(t, result.SwitchError) require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path")
require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel) require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch")
require.Len(t, repo.modelRateLimitCalls, 1) require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
} }
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理) // TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
@@ -301,11 +305,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
func TestParseAntigravitySmartRetryInfo(t *testing.T) { func TestParseAntigravitySmartRetryInfo(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
body string body string
expectedDelay time.Duration expectedDelay time.Duration
expectedModel string expectedModel string
expectedNil bool expectedNil bool
expectedIsModelCapacityExhausted bool
}{ }{
{ {
name: "valid complete response with RATE_LIMIT_EXCEEDED", name: "valid complete response with RATE_LIMIT_EXCEEDED",
@@ -368,8 +373,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
"message": "No capacity available for model gemini-3-pro-high on the server" "message": "No capacity available for model gemini-3-pro-high on the server"
} }
}`, }`,
expectedDelay: 39 * time.Second, expectedDelay: 39 * time.Second,
expectedModel: "gemini-3-pro-high", expectedModel: "gemini-3-pro-high",
expectedIsModelCapacityExhausted: true,
}, },
{ {
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil", name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
@@ -480,6 +486,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
if result.ModelName != tt.expectedModel { if result.ModelName != tt.expectedModel {
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel) t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
} }
if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
}
}) })
} }
} }
@@ -491,13 +500,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
apiKeyAccount := &Account{Type: AccountTypeAPIKey} apiKeyAccount := &Account{Type: AccountTypeAPIKey}
tests := []struct { tests := []struct {
name string name string
account *Account account *Account
body string body string
expectedShouldRetry bool expectedShouldRetry bool
expectedShouldRateLimit bool expectedShouldRateLimit bool
minWait time.Duration expectedIsModelCapacityExhausted bool
modelName string minWait time.Duration
modelName string
}{ }{
{ {
name: "OAuth account with short delay (< 7s) - smart retry", name: "OAuth account with short delay (< 7s) - smart retry",
@@ -611,13 +621,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
] ]
} }
}`, }`,
expectedShouldRetry: false, expectedShouldRetry: true,
expectedShouldRateLimit: true, expectedShouldRateLimit: false,
minWait: 39 * time.Second, expectedIsModelCapacityExhausted: true,
modelName: "gemini-3-pro-high", minWait: 1 * time.Second,
modelName: "gemini-3-pro-high",
}, },
{ {
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit", name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait",
account: oauthAccount, account: oauthAccount,
body: `{ body: `{
"error": { "error": {
@@ -629,10 +640,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
"message": "No capacity available for model gemini-2.5-flash on the server" "message": "No capacity available for model gemini-2.5-flash on the server"
} }
}`, }`,
expectedShouldRetry: false, expectedShouldRetry: true,
expectedShouldRateLimit: true, expectedShouldRateLimit: false,
minWait: 30 * time.Second, expectedIsModelCapacityExhausted: true,
modelName: "gemini-2.5-flash", minWait: 1 * time.Second,
modelName: "gemini-2.5-flash",
}, },
{ {
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit", name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
@@ -656,13 +668,16 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body)) shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
if shouldRetry != tt.expectedShouldRetry { if shouldRetry != tt.expectedShouldRetry {
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry) t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
} }
if shouldRateLimit != tt.expectedShouldRateLimit { if shouldRateLimit != tt.expectedShouldRateLimit {
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit) t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
} }
if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
}
if shouldRetry { if shouldRetry {
if wait < tt.minWait { if wait < tt.minWait {
t.Errorf("wait = %v, want >= %v", wait, tt.minWait) t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
@@ -915,6 +930,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) {
} }
} }
func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
defer func() {
antigravity.BaseURLs = oldBaseURLs
}()
prodURL := "https://prod.test"
dailyURL := "https://daily.test"
antigravity.BaseURLs = []string{dailyURL, prodURL}
resolved := resolveAntigravityForwardBaseURL()
require.Equal(t, dailyURL, resolved)
}
func TestAntigravityAccountSwitchError_Error(t *testing.T) { func TestAntigravityAccountSwitchError_Error(t *testing.T) {
err := &AntigravityAccountSwitchError{ err := &AntigravityAccountSwitchError{
OriginalAccountID: 789, OriginalAccountID: 789,

View File

@@ -153,13 +153,14 @@ func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *te
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
} }
// 503 + 39s >= 7s 阈值 // 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel
respBody := []byte(`{ respBody := []byte(`{
"error": { "error": {
"code": 503, "code": 503,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
] ]
} }
@@ -339,13 +340,14 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit // TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
// 对照组503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流 // 对照组503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED因为后者走独立的 60 次重试路径
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) { func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
failRespBody := `{ failRespBody := `{
"error": { "error": {
"code": 503, "code": 503,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
] ]
} }
@@ -371,9 +373,9 @@ func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *t
respBody := []byte(`{ respBody := []byte(`{
"error": { "error": {
"code": 503, "code": 503,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
] ]
} }

View File

@@ -294,8 +294,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)") require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
} }
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError // TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) { // MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号
func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) {
repo := &stubAntigravityAccountRepo{} repo := &stubAntigravityAccountRepo{}
account := &Account{ account := &Account{
ID: 3, ID: 3,
@@ -304,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
} }
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值 // 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s
respBody := []byte(`{ respBody := []byte(`{
"error": { "error": {
"code": 503, "code": 503,
@@ -322,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
Body: io.NopCloser(bytes.NewReader(respBody)), Body: io.NopCloser(bytes.NewReader(respBody)),
} }
// mock: 第 1 次重试返回 200 成功
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{
{StatusCode: http.StatusOK, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))},
},
errors: []error{nil},
}
params := antigravityRetryLoopParams{ params := antigravityRetryLoopParams{
ctx: context.Background(), ctx: context.Background(),
prefix: "[test]", prefix: "[test]",
@@ -330,6 +339,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
action: "generateContent", action: "generateContent",
body: []byte(`{"input":"test"}`), body: []byte(`{"input":"test"}`),
accountRepo: repo, accountRepo: repo,
httpUpstream: upstream,
isStickySession: true, isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil return nil
@@ -343,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action) require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp) require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.err) require.Nil(t, result.err)
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted") require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
// 验证模型限流已设置 // 不应设置模型限流
require.Len(t, repo.modelRateLimitCalls, 1) require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey) require.Len(t, upstream.calls, 1, "should have made one retry call before success")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消
func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-3",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
// 立即取消上下文,验证重试循环能正确退出
ctx, cancel := context.WithCancel(context.Background())
cancel()
params := antigravityRetryLoopParams{
ctx: ctx,
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Error(t, result.err, "should return context error")
require.Nil(t, result.switchError, "should not return switchError on context cancel")
require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel")
} }
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑 // TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
@@ -1129,20 +1190,20 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
} }
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession // TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定 // 429 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) { func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{ failRespBody := `{
"error": { "error": {
"code": 503, "code": 429,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
] ]
} }
}` }`
failResp := &http.Response{ failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable, StatusCode: http.StatusTooManyRequests,
Header: http.Header{}, Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)), Body: io.NopCloser(strings.NewReader(failRespBody)),
} }
@@ -1162,16 +1223,16 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
respBody := []byte(`{ respBody := []byte(`{
"error": { "error": {
"code": 503, "code": 429,
"status": "UNAVAILABLE", "status": "RESOURCE_EXHAUSTED",
"details": [ "details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
] ]
} }
}`) }`)
resp := &http.Response{ resp := &http.Response{
StatusCode: http.StatusServiceUnavailable, StatusCode: http.StatusTooManyRequests,
Header: http.Header{}, Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)), Body: io.NopCloser(bytes.NewReader(respBody)),
} }

View File

@@ -7,12 +7,14 @@ import (
"log/slog" "log/slog"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
const ( const (
antigravityTokenRefreshSkew = 3 * time.Minute antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute antigravityTokenCacheSkew = 5 * time.Minute
antigravityBackfillCooldown = 5 * time.Minute
) )
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) // AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
@@ -23,6 +25,7 @@ type AntigravityTokenProvider struct {
accountRepo AccountRepository accountRepo AccountRepository
tokenCache AntigravityTokenCache tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
} }
func NewAntigravityTokenProvider( func NewAntigravityTokenProvider(
@@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if err != nil { if err != nil {
return "", err return "", err
} }
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) p.mergeCredentials(account, tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
} }
@@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials") return "", errors.New("access_token not found in credentials")
} }
// 如果账号还没有 project_id尝试在线补齐避免请求 daily/sandbox 时出现
// "Invalid project resource name projects/"。
// 仅调用 loadProjectIDWithRetry不刷新 OAuth token带冷却机制防止频繁重试。
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
if p.shouldAttemptBackfill(account.ID) {
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
}
}
}
}
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) // 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil { if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
@@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil return accessToken, nil
} }
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
}
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id冷却期内不重复尝试
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
if v, ok := p.backfillCooldown.Load(accountID); ok {
if lastAttempt, ok := v.(time.Time); ok {
return time.Since(lastAttempt) > antigravityBackfillCooldown
}
}
return true
}
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
p.backfillCooldown.Store(accountID, time.Now())
}
func AntigravityTokenCacheKey(account *Account) string { func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" { if projectID != "" {

View File

@@ -61,6 +61,11 @@ func applyErrorPassthroughRule(
errMsg = *rule.CustomMessage errMsg = *rule.CustomMessage
} }
// 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。
if rule.SkipMonitoring {
c.Set(OpsSkipPassthroughKey, true)
}
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。 // 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
errType = "upstream_error" errType = "upstream_error"
return status, errType, errMsg, true return status, errType, errMsg, true

View File

@@ -194,6 +194,63 @@ func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
assert.Equal(t, "Gemini上游失败", errField["message"]) assert.Equal(t, "Gemini上游失败", errField["message"])
} }
func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
rule.SkipMonitoring = true
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
BindErrorPassthroughService(c, ruleSvc)
_, _, _, matched := applyErrorPassthroughRule(
c,
PlatformAnthropic,
http.StatusBadRequest,
[]byte(`{"error":{"message":"prompt is too long"}}`),
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
)
assert.True(t, matched)
v, exists := c.Get(OpsSkipPassthroughKey)
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
boolVal, ok := v.(bool)
assert.True(t, ok, "value should be bool")
assert.True(t, boolVal)
}
func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
rule.SkipMonitoring = false
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
BindErrorPassthroughService(c, ruleSvc)
_, _, _, matched := applyErrorPassthroughRule(
c,
PlatformAnthropic,
http.StatusBadRequest,
[]byte(`{"error":{"message":"prompt is too long"}}`),
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
)
assert.True(t, matched)
_, exists := c.Get(OpsSkipPassthroughKey)
assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false")
}
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule { func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
return &model.ErrorPassthroughRule{ return &model.ErrorPassthroughRule{
ID: 1, ID: 1,

View File

@@ -45,10 +45,20 @@ type ErrorPassthroughService struct {
cache ErrorPassthroughCache cache ErrorPassthroughCache
// 本地内存缓存,用于快速匹配 // 本地内存缓存,用于快速匹配
localCache []*model.ErrorPassthroughRule localCache []*cachedPassthroughRule
localCacheMu sync.RWMutex localCacheMu sync.RWMutex
} }
// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower
type cachedPassthroughRule struct {
*model.ErrorPassthroughRule
lowerKeywords []string // 预计算的小写关键词
lowerPlatforms []string // 预计算的小写平台
errorCodeSet map[int]struct{} // 预计算的 error code set
}
const maxBodyMatchLen = 8 << 10 // 8KB错误信息不会在 8KB 之后才出现
// NewErrorPassthroughService 创建错误透传规则服务 // NewErrorPassthroughService 创建错误透传规则服务
func NewErrorPassthroughService( func NewErrorPassthroughService(
repo ErrorPassthroughRepository, repo ErrorPassthroughRepository,
@@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
return nil return nil
} }
bodyStr := strings.ToLower(string(body)) lowerPlatform := strings.ToLower(platform)
var bodyLower string // 延迟初始化,只在需要关键词匹配时计算
var bodyLowerDone bool
for _, rule := range rules { for _, rule := range rules {
if !rule.Enabled { if !rule.Enabled {
continue continue
} }
if !s.platformMatches(rule, platform) { if !s.platformMatchesCached(rule, lowerPlatform) {
continue continue
} }
if s.ruleMatches(rule, statusCode, bodyStr) { if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) {
return rule return rule.ErrorPassthroughRule
} }
} }
@@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
} }
// getCachedRules 获取缓存的规则列表(按优先级排序) // getCachedRules 获取缓存的规则列表(按优先级排序)
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule { func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
s.localCacheMu.RLock() s.localCacheMu.RLock()
rules := s.localCache rules := s.localCache
s.localCacheMu.RUnlock() s.localCacheMu.RUnlock()
@@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
return nil return nil
} }
// setLocalCache 设置本地缓存 // setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) { func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
cached := make([]*cachedPassthroughRule, len(rules))
for i, r := range rules {
cr := &cachedPassthroughRule{ErrorPassthroughRule: r}
if len(r.Keywords) > 0 {
cr.lowerKeywords = make([]string, len(r.Keywords))
for j, kw := range r.Keywords {
cr.lowerKeywords[j] = strings.ToLower(kw)
}
}
if len(r.Platforms) > 0 {
cr.lowerPlatforms = make([]string, len(r.Platforms))
for j, p := range r.Platforms {
cr.lowerPlatforms[j] = strings.ToLower(p)
}
}
if len(r.ErrorCodes) > 0 {
cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes))
for _, code := range r.ErrorCodes {
cr.errorCodeSet[code] = struct{}{}
}
}
cached[i] = cr
}
// 按优先级排序 // 按优先级排序
sorted := make([]*model.ErrorPassthroughRule, len(rules)) sort.Slice(cached, func(i, j int) bool {
copy(sorted, rules) return cached[i].Priority < cached[j].Priority
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
}) })
s.localCacheMu.Lock() s.localCacheMu.Lock()
s.localCache = sorted s.localCache = cached
s.localCacheMu.Unlock() s.localCacheMu.Unlock()
} }
@@ -273,62 +307,79 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
} }
} }
// platformMatches 检查平台是否匹配 // ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool { func ensureBodyLower(body []byte, bodyLower *string, done *bool) string {
// 如果没有配置平台限制,则匹配所有平台 if *done {
if len(rule.Platforms) == 0 { return *bodyLower
}
b := body
if len(b) > maxBodyMatchLen {
b = b[:maxBodyMatchLen]
}
*bodyLower = strings.ToLower(string(b))
*done = true
return *bodyLower
}
// platformMatchesCached 使用预计算的小写平台检查是否匹配
func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool {
if len(rule.lowerPlatforms) == 0 {
return true return true
} }
for _, p := range rule.lowerPlatforms {
platform = strings.ToLower(platform) if p == lowerPlatform {
for _, p := range rule.Platforms {
if strings.ToLower(p) == platform {
return true return true
} }
} }
return false return false
} }
// ruleMatches 检查规则是否匹配 // ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool { func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool {
hasErrorCodes := len(rule.ErrorCodes) > 0 hasErrorCodes := len(rule.errorCodeSet) > 0
hasKeywords := len(rule.Keywords) > 0 hasKeywords := len(rule.lowerKeywords) > 0
// 如果没有配置任何条件,不匹配
if !hasErrorCodes && !hasKeywords { if !hasErrorCodes && !hasKeywords {
return false return false
} }
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode) codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode)
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
if rule.MatchMode == model.MatchModeAll { if rule.MatchMode == model.MatchModeAll {
// "all" 模式:所有配置的条件都必须满足 // "all" 模式:所有配置的条件都必须满足,短路
return codeMatch && keywordMatch if hasErrorCodes && !codeMatch {
return false
}
if hasKeywords {
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
}
return codeMatch
} }
// "any" 模式:任一条件满足即可 // "any" 模式:任一条件满足即可,短路
if hasErrorCodes && hasKeywords { if hasErrorCodes && hasKeywords {
return codeMatch || keywordMatch if codeMatch {
return true
}
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
} }
return codeMatch && keywordMatch // 只配置了一种条件
if hasKeywords {
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
}
return codeMatch
} }
// containsInt 检查切片是否包含指定整数 // containsIntSet 使用 map 查找替代线性扫描
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool { func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool {
for _, v := range slice { _, ok := set[val]
if v == val { return ok
return true }
}
} // containsAnyKeywordCached 使用预计算的小写关键词检查匹配
return false func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool {
} for _, kw := range lowerKeywords {
if strings.Contains(bodyLower, kw) {
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
for _, kw := range keywords {
if strings.Contains(bodyLower, strings.ToLower(kw)) {
return true return true
} }
} }

View File

@@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic
return svc return svc
} }
// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule测试用
func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule {
cr := &cachedPassthroughRule{ErrorPassthroughRule: rule}
if len(rule.Keywords) > 0 {
cr.lowerKeywords = make([]string, len(rule.Keywords))
for j, kw := range rule.Keywords {
cr.lowerKeywords[j] = strings.ToLower(kw)
}
}
if len(rule.Platforms) > 0 {
cr.lowerPlatforms = make([]string, len(rule.Platforms))
for j, p := range rule.Platforms {
cr.lowerPlatforms[j] = strings.ToLower(p)
}
}
if len(rule.ErrorCodes) > 0 {
cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes))
for _, code := range rule.ErrorCodes {
cr.errorCodeSet[code] = struct{}{}
}
}
return cr
}
// ============================================================================= // =============================================================================
// 测试 ruleMatches 核心匹配逻辑 // 测试 ruleMatchesOptimized 核心匹配逻辑
// ============================================================================= // =============================================================================
func TestRuleMatches_NoConditions(t *testing.T) { func TestRuleMatches_NoConditions(t *testing.T) {
// 没有配置任何条件时,不应该匹配 // 没有配置任何条件时,不应该匹配
svc := newTestService(nil) svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{ rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true, Enabled: true,
ErrorCodes: []int{}, ErrorCodes: []int{},
Keywords: []string{}, Keywords: []string{},
MatchMode: model.MatchModeAny, MatchMode: model.MatchModeAny,
} })
assert.False(t, svc.ruleMatches(rule, 422, "some error message"), var bodyLower string
var bodyLowerDone bool
assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone),
"没有配置条件时不应该匹配") "没有配置条件时不应该匹配")
} }
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
svc := newTestService(nil) svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{ rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true, Enabled: true,
ErrorCodes: []int{422, 400}, ErrorCodes: []int{422, 400},
Keywords: []string{}, Keywords: []string{},
MatchMode: model.MatchModeAny, MatchMode: model.MatchModeAny,
} })
tests := []struct { tests := []struct {
name string name string
@@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body) var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
assert.Equal(t, tt.expected, result) assert.Equal(t, tt.expected, result)
}) })
} }
@@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
svc := newTestService(nil) svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{ rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true, Enabled: true,
ErrorCodes: []int{}, ErrorCodes: []int{},
Keywords: []string{"context limit", "model not supported"}, Keywords: []string{"context limit", "model not supported"},
MatchMode: model.MatchModeAny, MatchMode: model.MatchModeAny,
} })
tests := []struct { tests := []struct {
name string name string
@@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
{"关键词匹配 context limit", 500, "error: context limit reached", true}, {"关键词匹配 context limit", 500, "error: context limit reached", true},
{"关键词匹配 model not supported", 400, "the model not supported here", true}, {"关键词匹配 model not supported", 400, "the model not supported here", true},
{"关键词不匹配", 422, "some other error", false}, {"关键词不匹配", 422, "some other error", false},
// 注意ruleMatches 接收的 body 参数应该是已经转换为小写的 {"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true},
// 实际使用时MatchRule 会先将 body 转换为小写再传给 ruleMatches
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// 模拟 MatchRule 的行为:先转换为小写 var bodyLower string
bodyLower := strings.ToLower(tt.body) var bodyLowerDone bool
result := svc.ruleMatches(rule, tt.statusCode, bodyLower) result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
assert.Equal(t, tt.expected, result) assert.Equal(t, tt.expected, result)
}) })
} }
@@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
// any 模式:错误码 OR 关键词 // any 模式:错误码 OR 关键词
svc := newTestService(nil) svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{ rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true, Enabled: true,
ErrorCodes: []int{422, 400}, ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"}, Keywords: []string{"context limit"},
MatchMode: model.MatchModeAny, MatchMode: model.MatchModeAny,
} })
tests := []struct { tests := []struct {
name string name string
@@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body) var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
assert.Equal(t, tt.expected, result, tt.reason) assert.Equal(t, tt.expected, result, tt.reason)
}) })
} }
@@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
func TestRuleMatches_BothConditions_AllMode(t *testing.T) { func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
// all 模式:错误码 AND 关键词 // all 模式:错误码 AND 关键词
svc := newTestService(nil) svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{ rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true, Enabled: true,
ErrorCodes: []int{422, 400}, ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"}, Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll, MatchMode: model.MatchModeAll,
} })
tests := []struct { tests := []struct {
name string name string
@@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body) var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
assert.Equal(t, tt.expected, result, tt.reason) assert.Equal(t, tt.expected, result, tt.reason)
}) })
} }
} }
// ============================================================================= // =============================================================================
// 测试 platformMatches 平台匹配逻辑 // 测试 platformMatchesCached 平台匹配逻辑
// ============================================================================= // =============================================================================
func TestPlatformMatches(t *testing.T) { func TestPlatformMatches(t *testing.T) {
@@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
rule := &model.ErrorPassthroughRule{ rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Platforms: tt.rulePlatforms, Platforms: tt.rulePlatforms,
} })
result := svc.platformMatches(rule, tt.requestPlatform) result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform))
assert.Equal(t, tt.expected, result) assert.Equal(t, tt.expected, result)
}) })
} }

View File

@@ -368,15 +368,31 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct { type UpstreamFailoverError struct {
StatusCode int StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配 ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应应在同一账号上重试 N 次再切换
} }
func (e *UpstreamFailoverError) Error() string { func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
} }
// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。
// 由 handler 层在同账号重试全部用尽、切换账号时调用。
func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) {
if failoverErr == nil || !failoverErr.RetryableOnSameAccount {
return
}
// 根据状态码选择封禁策略
switch failoverErr.StatusCode {
case http.StatusBadRequest:
tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]")
case http.StatusBadGateway:
tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]")
}
}
// GatewayService handles API gateway operations // GatewayService handles API gateway operations
type GatewayService struct { type GatewayService struct {
accountRepo AccountRepository accountRepo AccountRepository

View File

@@ -880,6 +880,37 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// ErrorPolicyNone → 原有逻辑 // ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
if resp.StatusCode == http.StatusBadRequest {
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if isGoogleProjectConfigError(msg400) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader) upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" { if upstreamReqID == "" {
@@ -1330,6 +1361,34 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// ErrorPolicyNone → 原有逻辑 // ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
if resp.StatusCode == http.StatusBadRequest {
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if isGoogleProjectConfigError(msg400) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody)))
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true}
}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody) evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))

View File

@@ -20,6 +20,10 @@ const (
// retry the specific upstream attempt (not just the client request). // retry the specific upstream attempt (not just the client request).
// This value is sanitized+trimmed before being persisted. // This value is sanitized+trimmed before being persisted.
OpsUpstreamRequestBodyKey = "ops_upstream_request_body" OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
// ops_error_logger 中间件检查此 key为 true 时跳过错误记录。
OpsSkipPassthroughKey = "ops_skip_passthrough"
) )
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
@@ -103,6 +107,37 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
evCopy := ev evCopy := ev
existing = append(existing, &evCopy) existing = append(existing, &evCopy)
c.Set(OpsUpstreamErrorsKey, existing) c.Set(OpsUpstreamErrorsKey, existing)
checkSkipMonitoringForUpstreamEvent(c, &evCopy)
}
// checkSkipMonitoringForUpstreamEvent checks whether the upstream error event
// matches a passthrough rule with skip_monitoring=true and, if so, sets the
// OpsSkipPassthroughKey on the context. This ensures intermediate retry /
// failover errors (which never go through the final applyErrorPassthroughRule
// path) can still suppress ops_error_logs recording.
func checkSkipMonitoringForUpstreamEvent(c *gin.Context, ev *OpsUpstreamErrorEvent) {
if ev.UpstreamStatusCode == 0 {
return
}
svc := getBoundErrorPassthroughService(c)
if svc == nil {
return
}
// Use the best available body representation for keyword matching.
// Even when body is empty, MatchRule can still match rules that only
// specify ErrorCodes (no Keywords), so we always call it.
body := ev.Detail
if body == "" {
body = ev.Message
}
rule := svc.MatchRule(ev.Platform, ev.UpstreamStatusCode, []byte(body))
if rule != nil && rule.SkipMonitoring {
c.Set(OpsSkipPassthroughKey, true)
}
} }
func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string { func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {

View File

@@ -0,0 +1,4 @@
-- Add skip_monitoring field to error_passthrough_rules table
-- When true, errors matching this rule will not be recorded in ops_error_logs
ALTER TABLE error_passthrough_rules
ADD COLUMN IF NOT EXISTS skip_monitoring BOOLEAN NOT NULL DEFAULT false;

View File

@@ -21,6 +21,7 @@ export interface ErrorPassthroughRule {
response_code: number | null response_code: number | null
passthrough_body: boolean passthrough_body: boolean
custom_message: string | null custom_message: string | null
skip_monitoring: boolean
description: string | null description: string | null
created_at: string created_at: string
updated_at: string updated_at: string
@@ -41,6 +42,7 @@ export interface CreateRuleRequest {
response_code?: number | null response_code?: number | null
passthrough_body?: boolean passthrough_body?: boolean
custom_message?: string | null custom_message?: string | null
skip_monitoring?: boolean
description?: string | null description?: string | null
} }
@@ -59,6 +61,7 @@ export interface UpdateRuleRequest {
response_code?: number | null response_code?: number | null
passthrough_body?: boolean passthrough_body?: boolean
custom_message?: string | null custom_message?: string | null
skip_monitoring?: boolean
description?: string | null description?: string | null
} }

View File

@@ -148,6 +148,16 @@
{{ rule.passthrough_body ? t('admin.errorPassthrough.passthrough') : t('admin.errorPassthrough.custom') }} {{ rule.passthrough_body ? t('admin.errorPassthrough.passthrough') : t('admin.errorPassthrough.custom') }}
</span> </span>
</div> </div>
<div v-if="rule.skip_monitoring" class="flex items-center gap-1">
<Icon
name="checkCircle"
size="xs"
class="text-yellow-500"
/>
<span class="text-gray-600 dark:text-gray-400">
{{ t('admin.errorPassthrough.skipMonitoring') }}
</span>
</div>
</div> </div>
</td> </td>
<td class="px-3 py-2"> <td class="px-3 py-2">
@@ -366,6 +376,19 @@
</div> </div>
</div> </div>
<!-- Skip Monitoring -->
<div class="flex items-center gap-1.5">
<input
type="checkbox"
v-model="form.skip_monitoring"
class="h-3.5 w-3.5 rounded border-gray-300 text-yellow-600 focus:ring-yellow-500"
/>
<span class="text-xs font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.errorPassthrough.form.skipMonitoring') }}
</span>
</div>
<p class="input-hint text-xs -mt-3">{{ t('admin.errorPassthrough.form.skipMonitoringHint') }}</p>
<!-- Enabled --> <!-- Enabled -->
<div class="flex items-center gap-1.5"> <div class="flex items-center gap-1.5">
<input <input
@@ -453,6 +476,7 @@ const form = reactive({
response_code: null as number | null, response_code: null as number | null,
passthrough_body: true, passthrough_body: true,
custom_message: null as string | null, custom_message: null as string | null,
skip_monitoring: false,
description: null as string | null description: null as string | null
}) })
@@ -497,6 +521,7 @@ const resetForm = () => {
form.response_code = null form.response_code = null
form.passthrough_body = true form.passthrough_body = true
form.custom_message = null form.custom_message = null
form.skip_monitoring = false
form.description = null form.description = null
errorCodesInput.value = '' errorCodesInput.value = ''
keywordsInput.value = '' keywordsInput.value = ''
@@ -520,6 +545,7 @@ const handleEdit = (rule: ErrorPassthroughRule) => {
form.response_code = rule.response_code form.response_code = rule.response_code
form.passthrough_body = rule.passthrough_body form.passthrough_body = rule.passthrough_body
form.custom_message = rule.custom_message form.custom_message = rule.custom_message
form.skip_monitoring = rule.skip_monitoring
form.description = rule.description form.description = rule.description
errorCodesInput.value = rule.error_codes.join(', ') errorCodesInput.value = rule.error_codes.join(', ')
keywordsInput.value = rule.keywords.join('\n') keywordsInput.value = rule.keywords.join('\n')
@@ -575,6 +601,7 @@ const handleSubmit = async () => {
response_code: form.passthrough_code ? null : form.response_code, response_code: form.passthrough_code ? null : form.response_code,
passthrough_body: form.passthrough_body, passthrough_body: form.passthrough_body,
custom_message: form.passthrough_body ? null : form.custom_message, custom_message: form.passthrough_body ? null : form.custom_message,
skip_monitoring: form.skip_monitoring,
description: form.description?.trim() || null description: form.description?.trim() || null
} }

View File

@@ -3353,6 +3353,7 @@ export default {
custom: 'Custom', custom: 'Custom',
code: 'Code', code: 'Code',
body: 'Body', body: 'Body',
skipMonitoring: 'Skip Monitoring',
// Columns // Columns
columns: { columns: {
@@ -3397,6 +3398,8 @@ export default {
passthroughBody: 'Passthrough upstream error message', passthroughBody: 'Passthrough upstream error message',
customMessage: 'Custom error message', customMessage: 'Custom error message',
customMessagePlaceholder: 'Error message to return to client...', customMessagePlaceholder: 'Error message to return to client...',
skipMonitoring: 'Skip monitoring',
skipMonitoringHint: 'When enabled, errors matching this rule will not be recorded in ops monitoring',
enabled: 'Enable this rule' enabled: 'Enable this rule'
}, },

View File

@@ -3527,6 +3527,7 @@ export default {
custom: '自定义', custom: '自定义',
code: '状态码', code: '状态码',
body: '消息体', body: '消息体',
skipMonitoring: '跳过监控',
// Columns // Columns
columns: { columns: {
@@ -3571,6 +3572,8 @@ export default {
passthroughBody: '透传上游错误信息', passthroughBody: '透传上游错误信息',
customMessage: '自定义错误信息', customMessage: '自定义错误信息',
customMessagePlaceholder: '返回给客户端的错误信息...', customMessagePlaceholder: '返回给客户端的错误信息...',
skipMonitoring: '跳过运维监控记录',
skipMonitoringHint: '开启后,匹配此规则的错误不会被记录到运维监控中',
enabled: '启用此规则' enabled: '启用此规则'
}, },