Files
sub2api/backend/internal/service/openai_chat_completions_forward.go
7976723 a17ac50118 fix: 修复 Chat Completions 编译错误和运行时 panic
1. 修复 WriteFilteredHeaders API 不兼容(2处):
   将 s.cfg.Security.ResponseHeaders 改为 s.responseHeaderFilter,
   因为 main 分支已将函数签名改为接受 *responseheaders.CompiledHeaderFilter

2. 修复 writer 生命周期导致的 nil pointer panic:
   ChatCompletions handler 替换了 c.Writer 但未恢复,导致
   OpsErrorLogger 中间件的 defer 释放 opsCaptureWriter 后,
   Logger 中间件调用 c.Writer.Status() 触发空指针解引用。
   通过保存并恢复 originalWriter 修复。

3. 为 chatCompletionsResponseWriter 添加防御性 Status() 和
   Written() 方法,包含 nil 安全检查

4. 恢复 gateway.go 中被误删的 net/http import
2026-03-11 13:49:13 +08:00

489 lines
13 KiB
Go

package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
)
type chatStreamingResult struct {
usage *OpenAIUsage
firstTokenMs *int
}
func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) {
// Parse request body once (avoid multiple parse/serialize cycles)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
originalModel := reqModel
bodyModified := false
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
reqBody["model"] = mappedModel
bodyModified = true
}
if reqStream && includeUsage {
streamOptions, _ := reqBody["stream_options"].(map[string]any)
if streamOptions == nil {
streamOptions = map[string]any{}
}
if _, ok := streamOptions["include_usage"]; !ok {
streamOptions["include_usage"] = true
reqBody["stream_options"] = streamOptions
bodyModified = true
}
}
if bodyModified {
var err error
body, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("serialize request body: %w", err)
}
}
// Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token)
if err != nil {
return nil, err
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
if err != nil {
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel)
if err != nil {
return nil, err
}
}
if usage == nil {
usage = &OpenAIUsage{}
}
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) {
var targetURL string
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
targetURL = openaiChatAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/chat/completions"
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("authorization", "Bearer "+token)
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiChatAllowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
req.Header.Set("user-agent", customUA)
}
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
return req, nil
}
func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for {
select {
case ev, ok := <-events:
if !ok {
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
lastDataAt = time.Now()
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
line = "data: " + correctedData
}
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
if firstTokenMs == nil {
if event := parseChatStreamEvent(data); event != nil {
if chatChunkHasDelta(event) {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
applyChatUsageFromEvent(event, usage)
}
} else {
if event := parseChatStreamEvent(data); event != nil {
applyChatUsageFromEvent(event, usage)
}
}
} else {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
}
func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
usage := &OpenAIUsage{}
var parsed map[string]any
if json.Unmarshal(body, &parsed) == nil {
if usageMap, ok := parsed["usage"].(map[string]any); ok {
applyChatUsageFromMap(usageMap, usage)
}
}
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
body = s.correctToolCallsInResponseBody(body)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func parseChatStreamEvent(data string) map[string]any {
if data == "" || data == "[DONE]" {
return nil
}
var event map[string]any
if json.Unmarshal([]byte(data), &event) != nil {
return nil
}
return event
}
func chatChunkHasDelta(event map[string]any) bool {
choices, ok := event["choices"].([]any)
if !ok {
return false
}
for _, choice := range choices {
choiceMap, ok := choice.(map[string]any)
if !ok {
continue
}
delta, ok := choiceMap["delta"].(map[string]any)
if !ok {
continue
}
if content, ok := delta["content"].(string); ok && strings.TrimSpace(content) != "" {
return true
}
if toolCalls, ok := delta["tool_calls"].([]any); ok && len(toolCalls) > 0 {
return true
}
if functionCall, ok := delta["function_call"].(map[string]any); ok && len(functionCall) > 0 {
return true
}
}
return false
}
func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) {
if event == nil || usage == nil {
return
}
usageMap, ok := event["usage"].(map[string]any)
if !ok {
return
}
applyChatUsageFromMap(usageMap, usage)
}
func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) {
if usageMap == nil || usage == nil {
return
}
promptTokens := int(getNumber(usageMap["prompt_tokens"]))
completionTokens := int(getNumber(usageMap["completion_tokens"]))
if promptTokens > 0 {
usage.InputTokens = promptTokens
}
if completionTokens > 0 {
usage.OutputTokens = completionTokens
}
}