diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index b7ac8423..1d3d8fdf 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1870,7 +1870,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) } - defer lease.Release() + // cleanExit 标记正常终端事件退出,此时上游不会再发送帧,连接可安全归还复用。 + // 所有异常路径(读写错误、error 事件等)已在各自分支中提前调用 MarkBroken, + // 因此 defer 中只需处理正常退出时不 MarkBroken 即可。 + cleanExit := false + defer func() { + if !cleanExit { + lease.MarkBroken() + } + lease.Release() + }() connID := strings.TrimSpace(lease.ConnID()) logOpenAIWSModeDebug( "connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v", @@ -2248,6 +2257,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } if isTerminalEvent { + cleanExit = true break } } @@ -2983,12 +2993,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( pinnedSessionConnID = connID } } + // lastTurnClean 标记最后一轮 sendAndRelay 是否正常完成(收到终端事件且客户端未断连)。 + // 所有异常路径(读写错误、error 事件、客户端断连)已在各自分支或上层(L3403)中 MarkBroken, + // 因此 releaseSessionLease 中只需在非正常结束时 MarkBroken。 + lastTurnClean := false releaseSessionLease := func() { if sessionLease == nil { return } - if dedicatedMode { - // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。 + if !lastTurnClean { sessionLease.MarkBroken() } unpinSessionConn(sessionConnID) @@ -3383,6 +3396,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) if relayErr != nil { + lastTurnClean = false if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { continue } @@ -3402,6 +3416,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( turnRetry = 0 turnPrevRecoveryTried = false lastTurnFinishedAt = time.Now() + lastTurnClean = true if hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(turn, result, nil) } diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 0d5004c0..7a76c385 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -380,7 +380,8 @@ func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) { require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_")) } - require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链") + // 条件式 MarkBroken:正常终端事件退出后连接归还复用,不再无条件销毁。 + require.Equal(t, int64(1), upgradeCount.Load(), "正常完成后连接应归还复用,不应每次新建") metrics := svc.SnapshotOpenAIWSPoolMetrics() require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1)) require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) @@ -964,6 +965,10 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t require.NotNil(t, result1) require.Equal(t, "resp_meta_1", result1.RequestID) + require.Len(t, captureConn.writes, 1) + firstWrite := requestToJSONString(captureConn.writes[0]) + require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) + rec2 := httptest.NewRecorder() c2, _ := gin.CreateTestContext(rec2) c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) @@ -977,7 +982,7 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接") require.Len(t, captureConn.writes, 2) - firstWrite := requestToJSONString(captureConn.writes[0]) + firstWrite = requestToJSONString(captureConn.writes[0]) secondWrite := requestToJSONString(captureConn.writes[1]) require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())