From ab4e8b2cf009c7067151b02adaa7bcd27865c8c0 Mon Sep 17 00:00:00 2001 From: QTom Date: Mon, 16 Mar 2026 10:28:11 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix(gateway):=20=E9=98=B2=E6=AD=A2=20OpenAI?= =?UTF-8?q?=20Codex=20=E8=B7=A8=E7=94=A8=E6=88=B7=E4=B8=B2=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 根因:多个用户共享同一 OAuth 账号时,conversation_id/session_id 头 未做用户隔离,导致上游 chatgpt.com 将不同用户的请求关联到同一会话。 HTTP SSE 修复: - 新增 isolateOpenAISessionID(apiKeyID, raw),将 API Key ID 混入 session 标识符(xxhash),确保不同 Key 的用户产生不同上游会话 - buildUpstreamRequest: OAuth 分支先 Del 客户端透传的 session 头, 再用隔离值覆盖 - buildUpstreamRequestOpenAIPassthrough: 透传路径同样隔离 - ForwardAsAnthropic: Anthropic Messages 兼容路径同步修复 - buildOpenAIWSHeaders: WS 路径的 OAuth session 头同步隔离 --- .../service/openai_gateway_messages.go | 7 ++- .../service/openai_gateway_service.go | 57 ++++++++++++++----- ..._gateway_service_session_isolation_test.go | 50 ++++++++++++++++ .../internal/service/openai_ws_forwarder.go | 21 +++++-- .../openai_ws_forwarder_success_test.go | 9 ++- 5 files changed, 119 insertions(+), 25 deletions(-) create mode 100644 backend/internal/service/openai_gateway_service_session_isolation_test.go diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 1e40ec6f..58714571 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -107,10 +107,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, fmt.Errorf("build upstream request: %w", err) } - // Override session_id with a deterministic UUID derived from the sticky - // session key (buildUpstreamRequest may have set it to the raw value). + // Override session_id with a deterministic UUID derived from the isolated + // session key, ensuring different API keys produce different upstream sessions. 
if promptCacheKey != "" { - upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey)) + apiKeyID := getAPIKeyIDFromContext(c) + upstreamReq.Header.Set("session_id", generateSessionUUID(isolateOpenAISessionID(apiKeyID, promptCacheKey))) } // 7. Send request diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 327ce916..c8876edb 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -24,6 +24,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/tidwall/gjson" @@ -787,6 +788,20 @@ func getAPIKeyIDFromContext(c *gin.Context) int64 { return apiKey.ID } +// isolateOpenAISessionID 将 apiKeyID 混入 session 标识符, +// 确保不同 API Key 的用户即使使用相同的原始 session_id/conversation_id, +// 到达上游的标识符也不同,防止跨用户会话碰撞。 +func isolateOpenAISessionID(apiKeyID int64, raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + h := xxhash.New() + _, _ = fmt.Fprintf(h, "k%d:", apiKeyID) + _, _ = h.WriteString(raw) + return fmt.Sprintf("%016x", h.Sum64()) +} + func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { if !result.Enabled { return @@ -2501,13 +2516,17 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { req.Header.Set("chatgpt-account-id", chatgptAccountID) } + apiKeyID := getAPIKeyIDFromContext(c) + // 先保存客户端原始值,再做 compact 补充,避免后续统一隔离时读到已处理的值。 + clientSessionID := strings.TrimSpace(req.Header.Get("session_id")) + clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id")) if 
isOpenAIResponsesCompactPath(c) { req.Header.Set("accept", "application/json") if req.Header.Get("version") == "" { req.Header.Set("version", codexCLIVersion) } - if req.Header.Get("session_id") == "" { - req.Header.Set("session_id", resolveOpenAICompactSessionID(c)) + if clientSessionID == "" { + clientSessionID = resolveOpenAICompactSessionID(c) } } else if req.Header.Get("accept") == "" { req.Header.Set("accept", "text/event-stream") @@ -2518,13 +2537,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( if req.Header.Get("originator") == "" { req.Header.Set("originator", "codex_cli_rs") } - if promptCacheKey != "" { - if req.Header.Get("conversation_id") == "" { - req.Header.Set("conversation_id", promptCacheKey) - } - if req.Header.Get("session_id") == "" { - req.Header.Set("session_id", promptCacheKey) - } + // 用隔离后的 session 标识符覆盖客户端透传值,防止跨用户会话碰撞。 + if clientSessionID == "" { + clientSessionID = promptCacheKey + } + if clientConversationID == "" { + clientConversationID = promptCacheKey + } + if clientSessionID != "" { + req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, clientSessionID)) + } + if clientConversationID != "" { + req.Header.Set("conversation_id", isolateOpenAISessionID(apiKeyID, clientConversationID)) } } @@ -2887,22 +2911,27 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. 
} } if account.Type == AccountTypeOAuth { + // 清除客户端透传的 session 头,后续用隔离后的值重新设置,防止跨用户会话碰撞。 + req.Header.Del("conversation_id") + req.Header.Del("session_id") + req.Header.Set("OpenAI-Beta", "responses=experimental") req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) + apiKeyID := getAPIKeyIDFromContext(c) if isOpenAIResponsesCompactPath(c) { req.Header.Set("accept", "application/json") if req.Header.Get("version") == "" { req.Header.Set("version", codexCLIVersion) } - if req.Header.Get("session_id") == "" { - req.Header.Set("session_id", resolveOpenAICompactSessionID(c)) - } + compactSession := resolveOpenAICompactSessionID(c) + req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, compactSession)) } else { req.Header.Set("accept", "text/event-stream") } if promptCacheKey != "" { - req.Header.Set("conversation_id", promptCacheKey) - req.Header.Set("session_id", promptCacheKey) + isolated := isolateOpenAISessionID(apiKeyID, promptCacheKey) + req.Header.Set("conversation_id", isolated) + req.Header.Set("session_id", isolated) } } diff --git a/backend/internal/service/openai_gateway_service_session_isolation_test.go b/backend/internal/service/openai_gateway_service_session_isolation_test.go new file mode 100644 index 00000000..d42fbcc5 --- /dev/null +++ b/backend/internal/service/openai_gateway_service_session_isolation_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsolateOpenAISessionID(t *testing.T) { + t.Run("empty_raw_returns_empty", func(t *testing.T) { + assert.Equal(t, "", isolateOpenAISessionID(1, "")) + assert.Equal(t, "", isolateOpenAISessionID(1, " ")) + }) + + t.Run("deterministic", func(t *testing.T) { + a := isolateOpenAISessionID(42, "sess_abc123") + b := isolateOpenAISessionID(42, "sess_abc123") + assert.Equal(t, a, b) + }) + + t.Run("different_apiKeyID_different_result", func(t *testing.T) { + a := 
isolateOpenAISessionID(1, "same_session") + b := isolateOpenAISessionID(2, "same_session") + require.NotEqual(t, a, b, "不同 API Key 使用相同 session_id 应产生不同隔离值") + }) + + t.Run("different_raw_different_result", func(t *testing.T) { + a := isolateOpenAISessionID(1, "session_a") + b := isolateOpenAISessionID(1, "session_b") + require.NotEqual(t, a, b) + }) + + t.Run("format_is_16_hex_chars", func(t *testing.T) { + result := isolateOpenAISessionID(99, "test_session") + assert.Len(t, result, 16, "应为 16 字符的 hex 字符串") + for _, ch := range result { + assert.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'), + "应仅包含 hex 字符: %c", ch) + } + }) + + t.Run("zero_apiKeyID_still_works", func(t *testing.T) { + result := isolateOpenAISessionID(0, "session") + assert.NotEmpty(t, result) + // apiKeyID=0 与 apiKeyID=1 应产生不同结果 + other := isolateOpenAISessionID(1, "session") + assert.NotEqual(t, result, other) + }) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index d4e4ea5a..b7ac8423 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1124,11 +1124,22 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders( headers.Set("accept-language", v) } } - if sessionResolution.SessionID != "" { - headers.Set("session_id", sessionResolution.SessionID) - } - if sessionResolution.ConversationID != "" { - headers.Set("conversation_id", sessionResolution.ConversationID) + // OAuth 账号:将 apiKeyID 混入 session 标识符,防止跨用户会话碰撞。 + if account != nil && account.Type == AccountTypeOAuth { + apiKeyID := getAPIKeyIDFromContext(c) + if sessionResolution.SessionID != "" { + headers.Set("session_id", isolateOpenAISessionID(apiKeyID, sessionResolution.SessionID)) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", isolateOpenAISessionID(apiKeyID, sessionResolution.ConversationID)) + } + } else { + if sessionResolution.SessionID != "" { + 
headers.Set("session_id", sessionResolution.SessionID) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", sessionResolution.ConversationID) + } } if state := strings.TrimSpace(turnState); state != "" { headers.Set(openAIWSTurnStateHeader, state) diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 912fade9..0d5004c0 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -454,8 +454,10 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段") require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true") require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta")) - require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id")) - require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) + // OAuth 账号的 session_id/conversation_id 应被 isolateOpenAISessionID 隔离, + // 测试中未设置 api_key 到 context,apiKeyID=0。 + require.Equal(t, isolateOpenAISessionID(0, "sess-oauth-1"), captureDialer.lastHeaders.Get("session_id")) + require.Equal(t, isolateOpenAISessionID(0, "conv-oauth-1"), captureDialer.lastHeaders.Get("conversation_id")) } func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) { @@ -596,7 +598,8 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK require.NotNil(t, result) require.Equal(t, "resp_prompt_cache_key", result.RequestID) - require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id")) + // OAuth 账号的 session_id 应被 isolateOpenAISessionID 隔离(apiKeyID=0,未在 context 设置)。 + require.Equal(t, isolateOpenAISessionID(0, "pcache_123"), captureDialer.lastHeaders.Get("session_id")) 
require.Empty(t, captureDialer.lastHeaders.Get("conversation_id")) require.NotNil(t, captureConn.lastWrite) require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists()) From 3741617ebd5131921a9bf7e8a82d47fbce780891 Mon Sep 17 00:00:00 2001 From: QTom Date: Mon, 16 Mar 2026 10:27:57 +0800 Subject: [PATCH 2/2] =?UTF-8?q?fix(gateway):=20WS=20=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E6=B1=A0=E6=9D=A1=E4=BB=B6=E5=BC=8F=20MarkBroken=20=E9=98=B2?= =?UTF-8?q?=E6=AD=A2=E8=B7=A8=E8=AF=B7=E6=B1=82=E4=B8=B2=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 正常终端事件(response.completed 等)退出后连接归还复用, 仅异常路径(读写错误、error 事件、客户端断连)MarkBroken 销毁。 Generate 模式: - 引入 cleanExit 标记,仅在 isTerminalEvent break 时设置 true - defer 中根据 cleanExit 决定是否 MarkBroken - 所有异常路径已在各自分支中提前调用 MarkBroken Ingress 模式: - 引入 lastTurnClean 标记,sendAndRelay 正常完成时设为 true - releaseSessionLease 根据 lastTurnClean 决定是否 MarkBroken - 错误路径重置 lastTurnClean = false - 客户端断连后 drain 仍保守 MarkBroken(L2916) --- .../internal/service/openai_ws_forwarder.go | 21 ++++++++++++++++--- .../openai_ws_forwarder_success_test.go | 9 ++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index b7ac8423..1d3d8fdf 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1870,7 +1870,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) } - defer lease.Release() + // cleanExit is set to true only when the loop breaks on a terminal event; the code assumes the upstream sends no further frames after that, so the connection can be returned to the pool for reuse. + // The deferred func below is the catch-all: every exit with cleanExit == false (early return, read/write error, error event, ...) marks the lease broken before releasing it, + // whether or not that path already called MarkBroken in its own branch (NOTE(review): this relies on MarkBroken being safe to call more than once — confirm). + cleanExit := false + defer func() { + if !cleanExit { + lease.MarkBroken() + } + lease.Release() + }() connID := strings.TrimSpace(lease.ConnID()) logOpenAIWSModeDebug( "connected account_id=%d account_type=%s transport=%s conn_id=%s 
conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v", @@ -2248,6 +2257,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } if isTerminalEvent { + cleanExit = true break } } @@ -2983,12 +2993,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( pinnedSessionConnID = connID } } + // lastTurnClean records whether the most recent sendAndRelay call returned without error (intended meaning: a terminal event was received and the client had not disconnected); it starts false, so a session that ends before any turn completes is treated as not clean. + // releaseSessionLease marks the lease broken whenever lastTurnClean is false, so the connection is returned for reuse only after a cleanly finished last turn; error paths (read/write errors, error events, client disconnect) may additionally have called MarkBroken themselves. + // NOTE(review): this replaces the previous unconditional MarkBroken for dedicatedMode sessions — confirm that reusing a dedicated connection across sessions after a clean turn is intended. + lastTurnClean := false releaseSessionLease := func() { if sessionLease == nil { return } - if dedicatedMode { - // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。 + if !lastTurnClean { sessionLease.MarkBroken() } unpinSessionConn(sessionConnID) @@ -3383,6 +3396,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) if relayErr != nil { + lastTurnClean = false if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { continue } @@ -3402,6 +3416,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( turnRetry = 0 turnPrevRecoveryTried = false lastTurnFinishedAt = time.Now() + lastTurnClean = true if hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(turn, result, nil) } diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 0d5004c0..7a76c385 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -380,7 +380,8 @@ func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) { require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_")) } - require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链") + // 条件式 MarkBroken:正常终端事件退出后连接归还复用,不再无条件销毁。 + require.Equal(t, int64(1), upgradeCount.Load(), "正常完成后连接应归还复用,不应每次新建") 
metrics := svc.SnapshotOpenAIWSPoolMetrics() require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1)) require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) @@ -964,6 +965,10 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t require.NotNil(t, result1) require.Equal(t, "resp_meta_1", result1.RequestID) + require.Len(t, captureConn.writes, 1) + firstWrite := requestToJSONString(captureConn.writes[0]) + require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) + rec2 := httptest.NewRecorder() c2, _ := gin.CreateTestContext(rec2) c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) @@ -977,7 +982,7 @@ func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *t require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接") require.Len(t, captureConn.writes, 2) - firstWrite := requestToJSONString(captureConn.writes[0]) + firstWrite = requestToJSONString(captureConn.writes[0]) secondWrite := requestToJSONString(captureConn.writes[1]) require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())