fix: 修复 gpt->claude 同步请求返回 SSE 的 bug (sync requests to the gpt->claude gateway incorrectly returned an SSE stream instead of a JSON response)

This commit is contained in:
shaw
2026-03-09 15:53:01 +08:00
parent a461538d58
commit 25178cdbe1

View File

@@ -39,7 +39,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("parse anthropic request: %w", err) return nil, fmt.Errorf("parse anthropic request: %w", err)
} }
originalModel := anthropicReq.Model originalModel := anthropicReq.Model
isStream := anthropicReq.Stream clientStream := anthropicReq.Stream // client's original stream preference
// 2. Convert Anthropic → Responses // 2. Convert Anthropic → Responses
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq) responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
@@ -47,6 +47,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("convert anthropic to responses: %w", err) return nil, fmt.Errorf("convert anthropic to responses: %w", err)
} }
// Upstream always uses streaming (upstream may not support sync mode).
// The client's original preference determines the response format.
responsesReq.Stream = true
isStream := true
// 2b. Handle BetaFastMode → service_tier: "priority" // 2b. Handle BetaFastMode → service_tier: "priority"
if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) { if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
responsesReq.ServiceTier = "priority" responsesReq.ServiceTier = "priority"
@@ -169,12 +174,14 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
// 9. Handle normal response // 9. Handle normal response
// Upstream is always streaming; choose response format based on client preference.
var result *OpenAIForwardResult var result *OpenAIForwardResult
var handleErr error var handleErr error
if isStream { if clientStream {
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime) result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
} else { } else {
result, handleErr = s.handleAnthropicNonStreamingResponse(resp, c, originalModel, mappedModel, startTime) // Client wants JSON: buffer the streaming response and assemble a JSON reply.
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
} }
// Propagate ServiceTier and ReasoningEffort to result for billing // Propagate ServiceTier and ReasoningEffort to result for billing
@@ -256,9 +263,13 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
} }
// handleAnthropicNonStreamingResponse reads a Responses API JSON response, // handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
// converts it to Anthropic Messages format, and writes it to the client. // the upstream streaming response, finds the terminal event (response.completed
func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse( // / response.incomplete / response.failed), converts the complete response to
// Anthropic Messages JSON format, and writes it to the client.
// This is used when the client requested stream=false but the upstream is always
// streaming.
func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
originalModel string, originalModel string,
@@ -267,29 +278,61 @@ func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
respBody, err := io.ReadAll(resp.Body) scanner := bufio.NewScanner(resp.Body)
if err != nil { scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
return nil, fmt.Errorf("read upstream response: %w", err)
}
var responsesResp apicompat.ResponsesResponse
if err := json.Unmarshal(respBody, &responsesResp); err != nil {
return nil, fmt.Errorf("parse responses response: %w", err)
}
anthropicResp := apicompat.ResponsesToAnthropic(&responsesResp, originalModel)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage var usage OpenAIUsage
if responsesResp.Usage != nil {
usage = OpenAIUsage{ for scanner.Scan() {
InputTokens: responsesResp.Usage.InputTokens, line := scanner.Text()
OutputTokens: responsesResp.Usage.OutputTokens,
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
} }
if responsesResp.Usage.InputTokensDetails != nil { payload := line[6:]
usage.CacheReadInputTokens = responsesResp.Usage.InputTokensDetails.CachedTokens
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Terminal events carry the complete ResponsesResponse with output + usage.
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
} }
} }
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResponse == nil {
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
return nil, fmt.Errorf("upstream stream ended without terminal event")
}
anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel)
if s.responseHeaderFilter != nil { if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
} }
@@ -307,6 +350,9 @@ func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
// handleAnthropicStreamingResponse reads Responses SSE events from upstream, // handleAnthropicStreamingResponse reads Responses SSE events from upstream,
// converts each to Anthropic SSE events, and writes them to the client. // converts each to Anthropic SSE events, and writes them to the client.
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
// pattern to send Anthropic ping events during periods of upstream silence,
// preventing proxy/client timeout disconnections.
func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
@@ -322,6 +368,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive") c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK) c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewResponsesEventToAnthropicState() state := apicompat.NewResponsesEventToAnthropicState()
@@ -333,28 +380,35 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() { // resultWithUsage builds the final result snapshot.
line := scanner.Text() resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { RequestID: requestID,
continue Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
} }
payload := line[6:] }
// processDataLine handles a single "data: ..." SSE line from upstream.
// Returns (clientDisconnected bool).
processDataLine := func(payload string) bool {
if firstChunk { if firstChunk {
firstChunk = false firstChunk = false
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms firstTokenMs = &ms
} }
// Parse the Responses SSE event
var event apicompat.ResponsesStreamEvent var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil { if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages stream: failed to parse event", logger.L().Warn("openai messages stream: failed to parse event",
zap.Error(err), zap.Error(err),
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
continue return false
} }
// Extract usage from completion events // Extract usage from completion events
@@ -381,28 +435,36 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
continue continue
} }
if _, err := fmt.Fprint(c.Writer, sse); err != nil { if _, err := fmt.Fprint(c.Writer, sse); err != nil {
// Client disconnected — return collected usage
logger.L().Info("openai messages stream: client disconnected", logger.L().Info("openai messages stream: client disconnected",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return &OpenAIForwardResult{ return true
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
} }
} }
if len(events) > 0 { if len(events) > 0 {
c.Writer.Flush() c.Writer.Flush()
} }
return false
} }
if err := scanner.Err(); err != nil { // finalizeStream sends any remaining Anthropic events and returns the result.
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { finalizeStream := func() (*OpenAIForwardResult, error) {
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
for _, evt := range finalEvents {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
c.Writer.Flush()
}
return resultWithUsage(), nil
}
// handleScanErr logs scanner errors if meaningful.
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages stream: read error", logger.L().Warn("openai messages stream: read error",
zap.Error(err), zap.Error(err),
zap.String("request_id", requestID), zap.String("request_id", requestID),
@@ -410,27 +472,94 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
} }
// Ensure the Anthropic stream is properly terminated // ── Determine keepalive interval ──
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { keepaliveInterval := time.Duration(0)
for _, evt := range finalEvents { if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
c.Writer.Flush()
} }
return &OpenAIForwardResult{ // ── No keepalive: fast synchronous path (no goroutine overhead) ──
RequestID: requestID, if keepaliveInterval <= 0 {
Usage: usage, for scanner.Scan() {
Model: originalModel, line := scanner.Text()
BillingModel: mappedModel, if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
Stream: true, continue
Duration: time.Since(startTime), }
FirstTokenMs: firstTokenMs, if processDataLine(line[6:]) {
}, nil return resultWithUsage(), nil
}
}
handleScanErr(scanner.Err())
return finalizeStream()
}
// ── With keepalive: goroutine + channel + select ──
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
go func() {
defer close(events)
for scanner.Scan() {
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
// Upstream closed
return finalizeStream()
}
if ev.err != nil {
handleScanErr(ev.err)
return finalizeStream()
}
lastDataAt = time.Now()
line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
case <-keepaliveTicker.C:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send Anthropic-format ping event
if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil {
// Client disconnected
logger.L().Info("openai messages stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
return resultWithUsage(), nil
}
c.Writer.Flush()
}
}
} }
// writeAnthropicError writes an error response in Anthropic Messages API format. // writeAnthropicError writes an error response in Anthropic Messages API format.