From b2bdba78dd15c4454d9236ea1a50258d68bbff98 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 3 May 2026 14:56:09 +0800 Subject: [PATCH] stabilize image request handling --- backend/internal/service/openai_images.go | 62 ++++++++++- .../internal/service/openai_images_test.go | 103 ++++++++++++++++++ 2 files changed, 159 insertions(+), 6 deletions(-) diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 4badcb1c..3da76525 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( var usage OpenAIUsage imageCount := parsed.N var firstTokenMs *int - if parsed.Stream { + if parsed.Stream && isEventStreamResponse(resp.Header) { streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) if err != nil { return nil, err @@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( usage := OpenAIUsage{} imageCount := 0 var firstTokenMs *int + var fallbackBody bytes.Buffer + fallbackBytes := int64(0) + fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg) + seenSSEData := false + fallbackTooLarge := false for { line, err := reader.ReadBytes('\n') @@ -824,11 +829,24 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( } flusher.Flush() - if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { - dataBytes := []byte(data) - mergeOpenAIUsage(&usage, dataBytes) - if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { - imageCount = count + if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok { + if data != "" && data != "[DONE]" { + seenSSEData = true + fallbackBody.Reset() + fallbackBytes = 0 + dataBytes := []byte(data) + mergeOpenAIUsage(&usage, dataBytes) + if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount { + imageCount = count + } + } + } else if !seenSSEData && !fallbackTooLarge { + fallbackBytes += int64(len(line)) + if fallbackBytes <= fallbackLimit { + _, _ = fallbackBody.Write(line) + } else { + fallbackTooLarge = true + fallbackBody.Reset() } } } @@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( return OpenAIUsage{}, 0, firstTokenMs, err } } + if !seenSSEData && fallbackBody.Len() > 0 { + body := bytes.TrimSpace(fallbackBody.Bytes()) + if len(body) > 0 { + mergeOpenAIUsage(&usage, body) + if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount { + imageCount = count + } + } + } return usage, imageCount, firstTokenMs, nil } +func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int { + if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 { + return count + } + if len(body) == 0 || !gjson.ValidBytes(body) { + return 0 + } + if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 { + return count + } + if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 { + return count + } + eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String()) + if eventType == "" || !strings.HasSuffix(eventType, ".completed") { + return 0 + } + if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() { + return 1 + } + return 0 +} + func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { if dst == nil { return diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 47113d4d..681e0e8e 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -446,6 +446,109 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) } +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req_img_stream_json"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 7, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 21, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.ImageOutputTokens) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_json_mislabeled"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 8, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 18, result.Usage.OutputTokens) + require.Equal(t, 8, result.Usage.ImageOutputTokens) + require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) { + body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`) + + require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body)) +} + func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) { gin.SetMode(gin.TestMode)