mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
stabilize image request handling
This commit is contained in:
@@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
|||||||
var usage OpenAIUsage
|
var usage OpenAIUsage
|
||||||
imageCount := parsed.N
|
imageCount := parsed.N
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
if parsed.Stream {
|
if parsed.Stream && isEventStreamResponse(resp.Header) {
|
||||||
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
|
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
|||||||
usage := OpenAIUsage{}
|
usage := OpenAIUsage{}
|
||||||
imageCount := 0
|
imageCount := 0
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
|
var fallbackBody bytes.Buffer
|
||||||
|
fallbackBytes := int64(0)
|
||||||
|
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
seenSSEData := false
|
||||||
|
fallbackTooLarge := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadBytes('\n')
|
line, err := reader.ReadBytes('\n')
|
||||||
@@ -824,13 +829,26 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
|||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" {
|
if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
|
||||||
|
if data != "" && data != "[DONE]" {
|
||||||
|
seenSSEData = true
|
||||||
|
fallbackBody.Reset()
|
||||||
|
fallbackBytes = 0
|
||||||
dataBytes := []byte(data)
|
dataBytes := []byte(data)
|
||||||
mergeOpenAIUsage(&usage, dataBytes)
|
mergeOpenAIUsage(&usage, dataBytes)
|
||||||
if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount {
|
if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
|
||||||
imageCount = count
|
imageCount = count
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if !seenSSEData && !fallbackTooLarge {
|
||||||
|
fallbackBytes += int64(len(line))
|
||||||
|
if fallbackBytes <= fallbackLimit {
|
||||||
|
_, _ = fallbackBody.Write(line)
|
||||||
|
} else {
|
||||||
|
fallbackTooLarge = true
|
||||||
|
fallbackBody.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
@@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
|
|||||||
return OpenAIUsage{}, 0, firstTokenMs, err
|
return OpenAIUsage{}, 0, firstTokenMs, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !seenSSEData && fallbackBody.Len() > 0 {
|
||||||
|
body := bytes.TrimSpace(fallbackBody.Bytes())
|
||||||
|
if len(body) > 0 {
|
||||||
|
mergeOpenAIUsage(&usage, body)
|
||||||
|
if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
|
||||||
|
imageCount = count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return usage, imageCount, firstTokenMs, nil
|
return usage, imageCount, firstTokenMs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
|
||||||
|
if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 {
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 {
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 {
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String())
|
||||||
|
if eventType == "" || !strings.HasSuffix(eventType, ".completed") {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
||||||
if dst == nil {
|
if dst == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -446,6 +446,109 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU
|
|||||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
httpUpstream: &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-Request-Id": []string{"req_img_stream_json"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 7,
|
||||||
|
Name: "openai-apikey",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
"base_url": "https://image-upstream.example/v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.Stream)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
require.Equal(t, 12, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 21, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 9, result.Usage.ImageOutputTokens)
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
httpUpstream: &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/event-stream"},
|
||||||
|
"X-Request-Id": []string{"req_img_stream_json_mislabeled"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 8,
|
||||||
|
Name: "openai-apikey",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
"base_url": "https://image-upstream.example/v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.Stream)
|
||||||
|
require.Equal(t, 1, result.ImageCount)
|
||||||
|
require.Equal(t, 10, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 18, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 8, result.Usage.ImageOutputTokens)
|
||||||
|
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
|
||||||
|
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
|
||||||
|
|
||||||
|
require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user