mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
fix(openai): 修复 WS passthrough 使用记录缺失推理强度和 User-Agent
- 为 OpenAI Responses WebSocket v2 passthrough 补齐每轮 reasoning_effort 元数据 - 传递首帧渠道映射前模型,保留模型后缀推理强度推导能力 - 增加 usage log 端到端回归,覆盖入站 User-Agent、显式 effort 和渠道映射场景
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -651,6 +652,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 test"),
|
||||
})
|
||||
|
||||
require.NotNil(t, got.log.UserAgent)
|
||||
require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent)
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "high", *got.log.ReasoningEffort)
|
||||
require.True(t, got.log.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"),
|
||||
channelMapping: map[string]string{
|
||||
"gpt-5.4-xhigh": "gpt-5.4",
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(),
|
||||
"上游首帧应使用渠道映射后的模型")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "xhigh", *got.log.ReasoningEffort,
|
||||
"usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
|
||||
userAgent: testStringPtr(""),
|
||||
})
|
||||
|
||||
require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "medium", *got.log.ReasoningEffort)
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -796,3 +837,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
return httptest.NewServer(router)
|
||||
}
|
||||
|
||||
type openAIResponsesWSUsageLogCase struct {
|
||||
firstPayload string
|
||||
userAgent *string
|
||||
channelMapping map[string]string
|
||||
}
|
||||
|
||||
type openAIResponsesWSUsageLogResult struct {
|
||||
log *service.UsageLog
|
||||
upstreamFirstPayload []byte
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
if s.account.Platform != platform {
|
||||
return nil, nil
|
||||
}
|
||||
return []service.Account{s.account}, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID != id {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.account
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||
service.UsageLogRepository
|
||||
created chan *service.UsageLog
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if s.created != nil {
|
||||
s.created <- log
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerChannelRepoStub struct {
|
||||
service.ChannelRepository
|
||||
channels []service.Channel
|
||||
groupPlatforms map[int64]string
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
return s.channels, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
out := make(map[int64]string, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" {
|
||||
out[groupID] = platform
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstreamPayloadCh := make(chan []byte, 1)
|
||||
upstreamErrCh := make(chan error, 1)
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
upstreamErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr != nil {
|
||||
upstreamErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
upstreamErrCh <- errors.New("unexpected upstream websocket message type")
|
||||
return
|
||||
}
|
||||
upstreamPayloadCh <- payload
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||
))
|
||||
cancelWrite()
|
||||
if writeErr != nil {
|
||||
upstreamErrCh <- writeErr
|
||||
return
|
||||
}
|
||||
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||
upstreamErrCh <- nil
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
groupID := int64(4201)
|
||||
account := service.Account{
|
||||
ID: 9901,
|
||||
Name: "openai-ws-passthrough-usage-e2e",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": upstreamServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.RunMode = config.RunModeSimple
|
||||
cfg.Default.RateMultiplier = 1
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
|
||||
usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
|
||||
|
||||
var channelSvc *service.ChannelService
|
||||
if len(tc.channelMapping) > 0 {
|
||||
channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
|
||||
channels: []service.Channel{{
|
||||
ID: 7701,
|
||||
Name: "openai-ws-e2e-channel",
|
||||
Status: service.StatusActive,
|
||||
GroupIDs: []int64{groupID},
|
||||
ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
|
||||
}},
|
||||
groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
usageRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
service.NewBillingService(cfg, nil),
|
||||
nil,
|
||||
billingCacheSvc,
|
||||
nil,
|
||||
&service.DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
channelSvc,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: gatewaySvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1801,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1701, Status: service.StatusActive},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
handlerServer := httptest.NewServer(router)
|
||||
defer handlerServer.Close()
|
||||
|
||||
headers := http.Header{}
|
||||
if tc.userAgent != nil {
|
||||
headers.Set("User-Agent", *tc.userAgent)
|
||||
}
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(
|
||||
dialCtx,
|
||||
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||
&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
|
||||
)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, event, err := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
var usageLog *service.UsageLog
|
||||
select {
|
||||
case usageLog = <-usageRepo.created:
|
||||
require.NotNil(t, usageLog)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待 WebSocket usage log 写入超时")
|
||||
}
|
||||
|
||||
var upstreamFirstPayload []byte
|
||||
select {
|
||||
case upstreamFirstPayload = <-upstreamPayloadCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 首帧超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case upstreamErr := <-upstreamErrCh:
|
||||
require.NoError(t, upstreamErr)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 结束超时")
|
||||
}
|
||||
|
||||
return openAIResponsesWSUsageLogResult{
|
||||
log: usageLog,
|
||||
upstreamFirstPayload: upstreamFirstPayload,
|
||||
}
|
||||
}
|
||||
|
||||
func testStringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user