diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 7676ffa3..c0de4476 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -1233,6 +1233,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ) hooks := &service.OpenAIWSIngressHooks{ + InitialRequestModel: reqModel, BeforeTurn: func(turn int) error { if turn == 1 { return nil diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 8ecee59a..0e21dc08 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -651,6 +652,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") } +func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) { + got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ + firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`, + userAgent: testStringPtr("codex_cli_rs/0.125.0 test"), + }) + + require.NotNil(t, got.log.UserAgent) + require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent) + require.NotNil(t, got.log.ReasoningEffort) + require.Equal(t, "high", *got.log.ReasoningEffort) + require.True(t, got.log.OpenAIWSMode) +} + +func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) { + got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ + firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`, + userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"), + channelMapping: map[string]string{ + "gpt-5.4-xhigh": "gpt-5.4", + }, + }) + + require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(), + "上游首帧应使用渠道映射后的模型") + require.NotNil(t, got.log.ReasoningEffort) + require.Equal(t, "xhigh", *got.log.ReasoningEffort, + "usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导") +} + +func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) { + got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ + firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`, + userAgent: testStringPtr(""), + }) + + require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底") + require.NotNil(t, got.log.ReasoningEffort) + require.Equal(t, "medium", *got.log.ReasoningEffort) +} + func TestSetOpenAIClientTransportHTTP(t *testing.T) { gin.SetMode(gin.TestMode) @@ -796,3 +837,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject router.GET("/openai/v1/responses", h.ResponsesWebSocket) return httptest.NewServer(router) } + +type openAIResponsesWSUsageLogCase struct { + firstPayload string + userAgent *string + channelMapping map[string]string +} + +type openAIResponsesWSUsageLogResult struct { + log *service.UsageLog + upstreamFirstPayload []byte +} + +type openAIWSUsageHandlerAccountRepoStub struct { + service.AccountRepository + account service.Account +} + +func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + if s.account.Platform != platform { + return nil, nil + } + return []service.Account{s.account}, nil +} + +func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + return s.ListSchedulableByPlatform(ctx, platform) +} + +func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) { + if s.account.ID != id { + return nil, nil + } + account := s.account + return &account, nil +} + +type openAIWSUsageHandlerUsageLogRepoStub struct { + service.UsageLogRepository + created chan *service.UsageLog +} + +func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) { + if s.created != nil { + s.created <- log + } + return true, nil +} + +type openAIWSUsageHandlerChannelRepoStub struct { + service.ChannelRepository + channels []service.Channel + groupPlatforms map[int64]string +} + +func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) { + return s.channels, nil +} + +func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { + out := make(map[int64]string, len(groupIDs)) + for _, groupID := range groupIDs { + if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" { + out[groupID] = platform + } + } + return out, nil +} + +func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult { + t.Helper() + gin.SetMode(gin.TestMode) + + upstreamPayloadCh := make(chan []byte, 1) + upstreamErrCh := make(chan error, 1) + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + upstreamErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, payload, readErr := conn.Read(readCtx) + cancelRead() + if readErr != nil { + upstreamErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + upstreamErrCh <- errors.New("unexpected upstream websocket message type") + return + } + upstreamPayloadCh <- payload + + writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second) + writeErr := conn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`, + )) + cancelWrite() + if writeErr != nil { + upstreamErrCh <- writeErr + return + } + _ = conn.Close(coderws.StatusNormalClosure, "done") + upstreamErrCh <- nil + })) + defer upstreamServer.Close() + + groupID := int64(4201) + account := service.Account{ + ID: 9901, + Name: "openai-ws-passthrough-usage-e2e", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": upstreamServer.URL, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + "openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, + }, + } + + cfg := &config.Config{} + cfg.RunMode = config.RunModeSimple + cfg.Default.RateMultiplier = 1 + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account} + usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)} + + var channelSvc *service.ChannelService + if len(tc.channelMapping) > 0 { + channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{ + channels: []service.Channel{{ + ID: 7701, + Name: "openai-ws-e2e-channel", + Status: service.StatusActive, + GroupIDs: []int64{groupID}, + ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping}, + }}, + groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI}, + }, nil, nil, nil) + } + + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) + gatewaySvc := service.NewOpenAIGatewayService( + accountRepo, + usageRepo, + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + service.NewBillingService(cfg, nil), + nil, + billingCacheSvc, + nil, + &service.DeferredService{}, + nil, + nil, + channelSvc, + nil, + nil, + ) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + h := &OpenAIGatewayHandler{ + gatewayService: gatewaySvc, + billingCacheService: billingCacheSvc, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + } + + apiKey := &service.APIKey{ + ID: 1801, + GroupID: &groupID, + User: &service.User{ID: 1701, Status: service.StatusActive}, + } + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1}) + c.Next() + }) + router.GET("/openai/v1/responses", h.ResponsesWebSocket) + handlerServer := httptest.NewServer(router) + defer handlerServer.Close() + + headers := http.Header{} + if tc.userAgent != nil { + headers.Set("User-Agent", *tc.userAgent) + } + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses", + &coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover}, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") + + var usageLog *service.UsageLog + select { + case usageLog = <-usageRepo.created: + require.NotNil(t, usageLog) + case <-time.After(3 * time.Second): + t.Fatal("等待 WebSocket usage log 写入超时") + } + + var upstreamFirstPayload []byte + select { + case upstreamFirstPayload = <-upstreamPayloadCh: + case <-time.After(3 * time.Second): + t.Fatal("等待上游 WebSocket 首帧超时") + } + + select { + case upstreamErr := <-upstreamErrCh: + require.NoError(t, upstreamErr) + case <-time.After(3 * time.Second): + t.Fatal("等待上游 WebSocket 结束超时") + } + + return openAIResponsesWSUsageLogResult{ + log: usageLog, + upstreamFirstPayload: upstreamFirstPayload, + } +} + +func testStringPtr(v string) *string { + return &v +} diff --git a/backend/internal/service/openai_fast_policy_ws_test.go b/backend/internal/service/openai_fast_policy_ws_test.go index 3316a242..7c8341b2 100644 --- a/backend/internal/service/openai_fast_policy_ws_test.go +++ b/backend/internal/service/openai_fast_policy_ws_test.go @@ -972,6 +972,62 @@ func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing "turn 3: response.create without service_tier overwrites billing to nil to match upstream default") } +func TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","reasoning":{"effort":"medium"},"service_tier":"priority"}`) + meta := newOpenAIWSPassthroughUsageMeta("", firstFrame) + capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame) + firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, capturedSessionModel, firstFrame) + require.NoError(t, firstErr) + require.Nil(t, firstBlocked) + meta.initFromFirstFrame(firstOut) + require.NotNil(t, meta.reasoningEffort.Load()) + require.Equal(t, "medium", *meta.reasoningEffort.Load()) + + process := func(payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { + capturedSessionModel = updated + } + meta.updateSessionRequestModel(payload) + requestModelForThisFrame := meta.requestModelForFrame(payload) + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload) + if policyErr == nil && blocked == nil && + strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + meta.updateFromResponseCreate(out, requestModelForThisFrame) + } + return out, blocked, policyErr + } + + _, blockedSession, errSession := process([]byte(`{"type":"session.update","session":{"model":"gpt-5-high"}}`)) + require.NoError(t, errSession) + require.Nil(t, blockedSession) + require.NotNil(t, meta.reasoningEffort.Load()) + require.Equal(t, "medium", *meta.reasoningEffort.Load(), "session.update 只刷新后续 fallback model,不覆盖当前 turn metadata") + + _, blockedCancel, errCancel := process([]byte(`{"type":"response.cancel","reasoning_effort":"x-high"}`)) + require.NoError(t, errCancel) + require.Nil(t, blockedCancel) + require.NotNil(t, meta.reasoningEffort.Load()) + require.Equal(t, "medium", *meta.reasoningEffort.Load(), "非 response.create 帧不能污染当前 turn metadata") + + _, blockedFlat, errFlat := process([]byte(`{"type":"response.create","reasoning_effort":"x-high"}`)) + require.NoError(t, errFlat) + require.Nil(t, blockedFlat) + require.NotNil(t, meta.reasoningEffort.Load()) + require.Equal(t, "xhigh", *meta.reasoningEffort.Load(), "flat reasoning_effort 必须进入 passthrough usage metadata") + + _, blockedClear, errClear := process([]byte(`{"type":"response.create","model":"gpt-4o"}`)) + require.NoError(t, errClear) + require.Nil(t, blockedClear) + require.Nil(t, meta.reasoningEffort.Load(), "新的 response.create 无 effort 且无可推导后缀时必须清空旧值") +} + // TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the // "block keeps previous" semantic: when policy returns block on a // response.create frame, that frame is never sent upstream, so billing tier diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index d1386b1b..201073e0 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string { // OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 type OpenAIWSIngressHooks struct { - BeforeTurn func(turn int) error - AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) + // InitialRequestModel 是首帧渠道映射前的请求模型,只用于 usage metadata + // 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。 + InitialRequestModel string + BeforeTurn func(turn int) error + AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) } func normalizeOpenAIWSLogValue(value string) string { diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 30fd4142..5246d37d 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR }() writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`)) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast","reasoning":{"effort":"HIGH"}}`)) cancelWrite() require.NoError(t, err) @@ -431,6 +431,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR require.Equal(t, 3, result.Usage.OutputTokens) require.NotNil(t, result.ServiceTier) require.Equal(t, "priority", *result.ServiceTier) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "high", *result.ReasoningEffort) case <-time.After(2 * time.Second): t.Fatal("未收到 passthrough turn 结果回调") } diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index 3dbb199a..8bc17d42 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -124,6 +124,73 @@ func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload [] return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original)) } +type openAIWSPassthroughUsageMeta struct { + serviceTier atomic.Pointer[string] + reasoningEffort atomic.Pointer[string] + + // 仅在 client->upstream filter goroutine 中读写;Load 侧通过上方原子指针同步。 + sessionRequestModel string +} + +func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta { + meta := &openAIWSPassthroughUsageMeta{ + sessionRequestModel: strings.TrimSpace(initialRequestModel), + } + if meta.sessionRequestModel == "" { + meta.sessionRequestModel = openAIWSPassthroughRequestModelForFrame(firstFrame) + } + return meta +} + +func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) { + if m == nil { + return + } + m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput)) + m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel)) +} + +func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) { + if m == nil { + return + } + if model := openAIWSPassthroughRequestModelFromSessionFrame(payload); model != "" { + m.sessionRequestModel = model + } +} + +func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string { + if m == nil { + return openAIWSPassthroughRequestModelForFrame(payload) + } + if model := openAIWSPassthroughRequestModelForFrame(payload); model != "" { + return model + } + return m.sessionRequestModel +} + +func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) { + if m == nil { + return + } + m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput)) + m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame)) +} + +func openAIWSPassthroughRequestModelForFrame(payload []byte) string { + if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, "model").String()) +} + +func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string { + if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String()) +} + const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) @@ -204,6 +271,11 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( // silently passed through, defeating the policy on every frame after // the first. capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage) + initialRequestModel := "" + if hooks != nil { + initialRequestModel = hooks.InitialRequestModel + } + usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage) updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage) if policyErr != nil { return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr) @@ -226,7 +298,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( } firstClientMessage = updatedFirst - // 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter + // 在 policy filter 之后再提取 service_tier / reasoning_effort 用于 + // usage 上报:filter // 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当 // 反映上游实际处理的 tier(nil = default),而不是用户最初请求的 // "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody)) @@ -237,11 +310,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( // codex-rs/core/src/client.rs build_responses_request 每次重新填值)。 // 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream // goroutine)和 OnTurnComplete / final result(runUpstreamToClient - // goroutine)之间同步当前 turn 的 service_tier。 - // extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型, - // 可直接 Store/Load 而无需额外封装。 - var requestServiceTierPtr atomic.Pointer[string] - requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage)) + // goroutine)之间同步当前 turn 的 usage metadata。 + usageMeta.initFromFirstFrame(firstClientMessage) wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { @@ -327,6 +397,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { capturedSessionModel = updated } + usageMeta.updateSessionRequestModel(payload) + requestModelForThisFrame := usageMeta.requestModelForFrame(payload) // Per-frame model first; if the client omits "model" on a // follow-up frame (legal in Realtime), fall back to the // session-level model captured from the first frame so the @@ -337,14 +409,14 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( model = capturedSessionModel } out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload) - // 多轮 passthrough billing:仅在成功(non-block / non-err) - // 的 response.create 帧上更新 requestServiceTierPtr,使用 + // 多轮 passthrough usage:仅在成功(non-block / non-err) + // 的 response.create 帧上更新 usageMeta,使用 // filter 处理后的 payload,与首帧 policy-after-extract 语义 // 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。 // - 非 response.create 帧(response.cancel / // conversation.item.create / session.update 等)不携带 - // per-response service_tier,不应覆盖前一轮值。 - // - blocked != nil:该帧不会发送上游,billing tier 应保持 + // per-response metadata,不应覆盖前一轮值。 + // - blocked != nil:该帧不会发送上游,usage metadata 应保持 // 上一轮值。 // - policyErr != nil:异常路径,保持上一轮值。 // - 不带 service_tier 的 response.create 会让 @@ -353,7 +425,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( // service_tier 时按 default 处理,billing 应如实反映。 if policyErr == nil && blocked == nil && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { - requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) + usageMeta.updateFromResponseCreate(out, requestModelForThisFrame) } return out, blocked, policyErr }, @@ -397,7 +469,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: turn.Usage.CacheReadInputTokens, }, Model: turn.RequestModel, - ServiceTier: requestServiceTierPtr.Load(), + ServiceTier: usageMeta.serviceTier.Load(), + ReasoningEffort: usageMeta.reasoningEffort.Load(), Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), @@ -445,7 +518,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, }, Model: relayResult.RequestModel, - ServiceTier: requestServiceTierPtr.Load(), + ServiceTier: usageMeta.serviceTier.Load(), + ReasoningEffort: usageMeta.reasoningEffort.Load(), Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders),