fix(openai): 修复 WS passthrough 使用记录缺失推理强度和 User-Agent

- 为 OpenAI Responses WebSocket v2 passthrough 补齐每轮 reasoning_effort 元数据
- 传递首帧渠道映射前模型,保留模型后缀推理强度推导能力
- 增加 usage log 端到端回归,覆盖入站 User-Agent、显式 effort 和渠道映射场景
This commit is contained in:
deqiying
2026-05-03 19:33:09 +08:00
parent 48912014a1
commit 23555be380
6 changed files with 468 additions and 16 deletions

View File

@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -651,6 +652,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
}
// TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort verifies
// that a passthrough round trip stores the inbound User-Agent verbatim, normalizes the explicit
// reasoning effort ("HIGH" -> "high"), and flags the usage log as WebSocket mode.
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
	result := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
		firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
		userAgent:    testStringPtr("codex_cli_rs/0.125.0 test"),
	})

	usage := result.log
	require.NotNil(t, usage.UserAgent)
	require.Equal(t, "codex_cli_rs/0.125.0 test", *usage.UserAgent)
	require.NotNil(t, usage.ReasoningEffort)
	require.Equal(t, "high", *usage.ReasoningEffort)
	require.True(t, usage.OpenAIWSMode)
}
// TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel verifies
// that channel model mapping rewrites the upstream first frame, while the usage log's reasoning
// effort is still derived from the pre-mapping model suffix ("-xhigh").
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
	result := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
		firstPayload:   `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
		userAgent:      testStringPtr("codex_cli_rs/0.125.0 mapped"),
		channelMapping: map[string]string{"gpt-5.4-xhigh": "gpt-5.4"},
	})

	upstreamModel := gjson.GetBytes(result.upstreamFirstPayload, "model").String()
	require.Equal(t, "gpt-5.4", upstreamModel,
		"上游首帧应使用渠道映射后的模型")
	require.NotNil(t, result.log.ReasoningEffort)
	require.Equal(t, "xhigh", *result.log.ReasoningEffort,
		"usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导")
}
// TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing verifies that an
// empty inbound User-Agent stays nil on the usage log (no fallback to the upstream handshake UA
// or a default UA), while the explicit reasoning effort is still persisted.
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
	result := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
		firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
		userAgent:    testStringPtr(""),
	})

	require.Nil(t, result.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底")
	require.NotNil(t, result.log.ReasoningEffort)
	require.Equal(t, "medium", *result.log.ReasoningEffort)
}
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -796,3 +837,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
return httptest.NewServer(router)
}
// openAIResponsesWSUsageLogCase describes one end-to-end WebSocket usage-log scenario:
// the first client frame, the inbound User-Agent header, and an optional channel model
// mapping applied before the frame is forwarded upstream.
type openAIResponsesWSUsageLogCase struct {
	firstPayload   string            // raw JSON of the first frame the client sends
	userAgent      *string           // inbound User-Agent header; nil leaves the header unset
	channelMapping map[string]string // optional request-model -> upstream-model mapping; empty disables the channel service
}
// openAIResponsesWSUsageLogResult carries what the scenario runner captured:
// the persisted usage log and the first payload observed by the fake upstream.
type openAIResponsesWSUsageLogResult struct {
	log                  *service.UsageLog // usage log written through the usage-log repo stub
	upstreamFirstPayload []byte            // first WebSocket frame received by the fake upstream server
}
// openAIWSUsageHandlerAccountRepoStub is an AccountRepository stub that serves a single
// fixed account; unimplemented methods panic via the embedded interface.
type openAIWSUsageHandlerAccountRepoStub struct {
	service.AccountRepository
	account service.Account // the only account the stub knows about
}
// ListSchedulableByPlatform returns the stub's single account when its platform matches,
// and an empty result otherwise.
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	if s.account.Platform == platform {
		return []service.Account{s.account}, nil
	}
	return nil, nil
}
// ListSchedulableByGroupIDAndPlatform ignores the group filter and delegates to
// ListSchedulableByPlatform, since the stub holds exactly one account.
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
	return s.ListSchedulableByPlatform(ctx, platform)
}
// GetByID returns a copy of the stub's account when the ID matches, and nil otherwise.
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
	if id != s.account.ID {
		return nil, nil
	}
	// Return a copy so callers cannot mutate the stub's fixture.
	copied := s.account
	return &copied, nil
}
// openAIWSUsageHandlerUsageLogRepoStub is a UsageLogRepository stub that forwards every
// created usage log to a capture channel so tests can await the write.
type openAIWSUsageHandlerUsageLogRepoStub struct {
	service.UsageLogRepository
	created chan *service.UsageLog // receives each log passed to Create; nil disables capture
}
// Create records the usage log on the capture channel (when one is configured) and
// reports success without touching any real storage.
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
	if ch := s.created; ch != nil {
		ch <- log
	}
	return true, nil
}
// openAIWSUsageHandlerChannelRepoStub is a ChannelRepository stub backed by fixed
// channel fixtures and a static group-to-platform mapping.
type openAIWSUsageHandlerChannelRepoStub struct {
	service.ChannelRepository
	channels       []service.Channel // fixtures returned by ListAll
	groupPlatforms map[int64]string  // group ID -> platform fixture used by GetGroupPlatforms
}
// ListAll returns the fixed channel fixtures configured on the stub.
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
	return s.channels, nil
}
// GetGroupPlatforms resolves the platform for each requested group from the stub's
// fixture map, skipping groups whose configured platform is blank or absent.
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
	resolved := make(map[int64]string, len(groupIDs))
	for _, id := range groupIDs {
		platform := strings.TrimSpace(s.groupPlatforms[id])
		if platform == "" {
			continue
		}
		resolved[id] = platform
	}
	return resolved, nil
}
// runOpenAIResponsesWebSocketUsageLogCase drives one end-to-end passthrough round trip:
// it starts a fake upstream WebSocket server, wires an OpenAIGatewayHandler with stubbed
// repositories, dials the handler as a client, sends tc.firstPayload, waits for the
// "response.completed" event, and returns the persisted usage log together with the first
// frame the upstream received.
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
	t.Helper()
	gin.SetMode(gin.TestMode)
	// Fake upstream: capture the first client frame, answer with a completed response
	// carrying usage data, then close normally. Any failure is funneled through upstreamErrCh.
	upstreamPayloadCh := make(chan []byte, 1)
	upstreamErrCh := make(chan error, 1)
	upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			upstreamErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()
		readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, payload, readErr := conn.Read(readCtx)
		cancelRead()
		if readErr != nil {
			upstreamErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			upstreamErrCh <- errors.New("unexpected upstream websocket message type")
			return
		}
		upstreamPayloadCh <- payload
		writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
		writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
			`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
		))
		cancelWrite()
		if writeErr != nil {
			upstreamErrCh <- writeErr
			return
		}
		_ = conn.Close(coderws.StatusNormalClosure, "done")
		upstreamErrCh <- nil
	}))
	defer upstreamServer.Close()
	// Account fixture: API-key account pointed at the fake upstream, with the
	// responses-websockets-v2 passthrough mode enabled via Extra flags.
	groupID := int64(4201)
	account := service.Account{
		ID:           9901,
		Name:         "openai-ws-passthrough-usage-e2e",
		Platform:     service.PlatformOpenAI,
		Type:         service.AccountTypeAPIKey,
		Status:       service.StatusActive,
		Schedulable:  true,
		Concurrency:  1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": upstreamServer.URL,
		},
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
			"openai_apikey_responses_websockets_v2_mode":    service.OpenAIWSIngressModePassthrough,
		},
	}
	// Config: enable the WS gateway paths under test and allow plain-HTTP upstream
	// (the httptest server is not TLS).
	cfg := &config.Config{}
	cfg.RunMode = config.RunModeSimple
	cfg.Default.RateMultiplier = 1
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
	accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
	usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
	// Channel service is only wired when the case supplies a model mapping, so the
	// no-mapping cases exercise the nil-channel-service path.
	var channelSvc *service.ChannelService
	if len(tc.channelMapping) > 0 {
		channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
			channels: []service.Channel{{
				ID:           7701,
				Name:         "openai-ws-e2e-channel",
				Status:       service.StatusActive,
				GroupIDs:     []int64{groupID},
				ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
			}},
			groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
		}, nil, nil, nil)
	}
	billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
	gatewaySvc := service.NewOpenAIGatewayService(
		accountRepo,
		usageRepo,
		nil,
		nil,
		nil,
		nil,
		nil,
		cfg,
		nil,
		nil,
		service.NewBillingService(cfg, nil),
		nil,
		billingCacheSvc,
		nil,
		&service.DeferredService{},
		nil,
		nil,
		channelSvc,
		nil,
		nil,
	)
	// Concurrency cache mock: always grant both user and account slots so the
	// handshake never fails on concurrency limits.
	cache := &concurrencyCacheMock{
		acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
			return true, nil
		},
		acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
			return true, nil
		},
	}
	h := &OpenAIGatewayHandler{
		gatewayService:      gatewaySvc,
		billingCacheService: billingCacheSvc,
		apiKeyService:       &service.APIKeyService{},
		concurrencyHelper:   NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
	}
	apiKey := &service.APIKey{
		ID:      1801,
		GroupID: &groupID,
		User:    &service.User{ID: 1701, Status: service.StatusActive},
	}
	// Router middleware injects the authenticated API key and subject that the
	// auth middleware would normally provide.
	router := gin.New()
	router.Use(func(c *gin.Context) {
		c.Set(string(middleware.ContextKeyAPIKey), apiKey)
		c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
		c.Next()
	})
	router.GET("/openai/v1/responses", h.ResponsesWebSocket)
	handlerServer := httptest.NewServer(router)
	defer handlerServer.Close()
	// Client dial: User-Agent header is only set when the case provides one
	// (a non-nil pointer to "" sets an explicitly empty header).
	headers := http.Header{}
	if tc.userAgent != nil {
		headers.Set("User-Agent", *tc.userAgent)
	}
	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(
		dialCtx,
		"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
		&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
	)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()
	// Send the case's first frame and expect the completed event relayed back.
	writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
	err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
	cancelWrite()
	require.NoError(t, err)
	readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
	_, event, err := clientConn.Read(readCtx)
	cancelRead()
	require.NoError(t, err)
	require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
	_ = clientConn.Close(coderws.StatusNormalClosure, "done")
	// Collect the three async outcomes: the persisted usage log, the upstream's
	// first frame, and the upstream handler's terminal error (nil on success).
	var usageLog *service.UsageLog
	select {
	case usageLog = <-usageRepo.created:
		require.NotNil(t, usageLog)
	case <-time.After(3 * time.Second):
		t.Fatal("等待 WebSocket usage log 写入超时")
	}
	var upstreamFirstPayload []byte
	select {
	case upstreamFirstPayload = <-upstreamPayloadCh:
	case <-time.After(3 * time.Second):
		t.Fatal("等待上游 WebSocket 首帧超时")
	}
	select {
	case upstreamErr := <-upstreamErrCh:
		require.NoError(t, upstreamErr)
	case <-time.After(3 * time.Second):
		t.Fatal("等待上游 WebSocket 结束超时")
	}
	return openAIResponsesWSUsageLogResult{
		log:                  usageLog,
		upstreamFirstPayload: upstreamFirstPayload,
	}
}
// testStringPtr returns a pointer to a copy of its argument, for building
// optional string fields in test fixtures.
func testStringPtr(v string) *string {
	out := v
	return &out
}