mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
fix(gateway): WS 连接池条件式 MarkBroken 防止跨请求串流
正常终端事件(response.completed 等)退出后连接归还复用, 仅异常路径(读写错误、error 事件、客户端断连)MarkBroken 销毁。 Generate 模式: - 引入 cleanExit 标记,仅在 isTerminalEvent break 时设置 true - defer 中根据 cleanExit 决定是否 MarkBroken - 所有异常路径已在各自分支中提前调用 MarkBroken Ingress 模式: - 引入 lastTurnClean 标记,sendAndRelay 正常完成时设为 true - releaseSessionLease 根据 lastTurnClean 决定是否 MarkBroken - 错误路径重置 lastTurnClean = false - 客户端断连后 drain 仍保守 MarkBroken(L2916)
This commit is contained in:
@@ -1870,7 +1870,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
|||||||
}
|
}
|
||||||
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
|
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
|
||||||
}
|
}
|
||||||
defer lease.Release()
|
// cleanExit 标记正常终端事件退出,此时上游不会再发送帧,连接可安全归还复用。
|
||||||
|
// 所有异常路径(读写错误、error 事件等)已在各自分支中提前调用 MarkBroken,
|
||||||
|
// 因此 defer 中只需处理正常退出时不 MarkBroken 即可。
|
||||||
|
cleanExit := false
|
||||||
|
defer func() {
|
||||||
|
if !cleanExit {
|
||||||
|
lease.MarkBroken()
|
||||||
|
}
|
||||||
|
lease.Release()
|
||||||
|
}()
|
||||||
connID := strings.TrimSpace(lease.ConnID())
|
connID := strings.TrimSpace(lease.ConnID())
|
||||||
logOpenAIWSModeDebug(
|
logOpenAIWSModeDebug(
|
||||||
"connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v",
|
"connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v",
|
||||||
@@ -2248,6 +2257,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isTerminalEvent {
|
if isTerminalEvent {
|
||||||
|
cleanExit = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2983,12 +2993,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
pinnedSessionConnID = connID
|
pinnedSessionConnID = connID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// lastTurnClean 标记最后一轮 sendAndRelay 是否正常完成(收到终端事件且客户端未断连)。
|
||||||
|
// 所有异常路径(读写错误、error 事件、客户端断连)已在各自分支或上层(L3403)中 MarkBroken,
|
||||||
|
// 因此 releaseSessionLease 中只需在非正常结束时 MarkBroken。
|
||||||
|
lastTurnClean := false
|
||||||
releaseSessionLease := func() {
|
releaseSessionLease := func() {
|
||||||
if sessionLease == nil {
|
if sessionLease == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if dedicatedMode {
|
if !lastTurnClean {
|
||||||
// dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。
|
|
||||||
sessionLease.MarkBroken()
|
sessionLease.MarkBroken()
|
||||||
}
|
}
|
||||||
unpinSessionConn(sessionConnID)
|
unpinSessionConn(sessionConnID)
|
||||||
@@ -3383,6 +3396,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
|
|
||||||
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
|
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
|
||||||
if relayErr != nil {
|
if relayErr != nil {
|
||||||
|
lastTurnClean = false
|
||||||
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
|
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -3402,6 +3416,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
turnRetry = 0
|
turnRetry = 0
|
||||||
turnPrevRecoveryTried = false
|
turnPrevRecoveryTried = false
|
||||||
lastTurnFinishedAt = time.Now()
|
lastTurnFinishedAt = time.Now()
|
||||||
|
lastTurnClean = true
|
||||||
if hooks != nil && hooks.AfterTurn != nil {
|
if hooks != nil && hooks.AfterTurn != nil {
|
||||||
hooks.AfterTurn(turn, result, nil)
|
hooks.AfterTurn(turn, result, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -380,7 +380,8 @@ func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
|
|||||||
require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_"))
|
require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_"))
|
||||||
}
|
}
|
||||||
|
|
||||||
require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链")
|
// 条件式 MarkBroken:正常终端事件退出后连接归还复用,不再无条件销毁。
|
||||||
|
require.Equal(t, int64(1), upgradeCount.Load(), "正常完成后连接应归还复用,不应每次新建")
|
||||||
metrics := svc.SnapshotOpenAIWSPoolMetrics()
|
metrics := svc.SnapshotOpenAIWSPoolMetrics()
|
||||||
require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1))
|
require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1))
|
||||||
require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
|
require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
|
||||||
@@ -964,6 +965,10 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t
|
|||||||
require.NotNil(t, result1)
|
require.NotNil(t, result1)
|
||||||
require.Equal(t, "resp_meta_1", result1.RequestID)
|
require.Equal(t, "resp_meta_1", result1.RequestID)
|
||||||
|
|
||||||
|
require.Len(t, captureConn.writes, 1)
|
||||||
|
firstWrite := requestToJSONString(captureConn.writes[0])
|
||||||
|
require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String())
|
||||||
|
|
||||||
rec2 := httptest.NewRecorder()
|
rec2 := httptest.NewRecorder()
|
||||||
c2, _ := gin.CreateTestContext(rec2)
|
c2, _ := gin.CreateTestContext(rec2)
|
||||||
c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||||
@@ -977,7 +982,7 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t
|
|||||||
require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接")
|
require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接")
|
||||||
require.Len(t, captureConn.writes, 2)
|
require.Len(t, captureConn.writes, 2)
|
||||||
|
|
||||||
firstWrite := requestToJSONString(captureConn.writes[0])
|
firstWrite = requestToJSONString(captureConn.writes[0])
|
||||||
secondWrite := requestToJSONString(captureConn.writes[1])
|
secondWrite := requestToJSONString(captureConn.writes[1])
|
||||||
require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String())
|
require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String())
|
||||||
require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())
|
require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())
|
||||||
|
|||||||
Reference in New Issue
Block a user