mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-16 21:04:45 +08:00
1. 修复 WriteFilteredHeaders API 不兼容(2处): 将 s.cfg.Security.ResponseHeaders 改为 s.responseHeaderFilter, 因为 main 分支已将函数签名改为接受 *responseheaders.CompiledHeaderFilter 2. 修复 writer 生命周期导致的 nil pointer panic: ChatCompletions handler 替换了 c.Writer 但未恢复,导致 OpsErrorLogger 中间件的 defer 释放 opsCaptureWriter 后, Logger 中间件调用 c.Writer.Status() 触发空指针解引用。 通过保存并恢复 originalWriter 修复。 3. 为 chatCompletionsResponseWriter 添加防御性 Status() 和 Written() 方法,包含 nil 安全检查 4. 恢复 gateway.go 中被误删的 net/http import
489 lines
13 KiB
Go
489 lines
13 KiB
Go
package service
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// chatStreamingResult bundles the per-request metrics collected while
// relaying a chat-completions SSE stream back to the client.
type chatStreamingResult struct {
	usage        *OpenAIUsage // token usage accumulated from stream events
	firstTokenMs *int         // latency to first content delta in ms; nil if no delta was seen
}
|
|
|
|
// forwardChatCompletions proxies an OpenAI /chat/completions request to the
// upstream account: it applies the account's model mapping, injects
// stream_options.include_usage for streaming requests when asked, forwards
// the request, classifies upstream errors (failover vs. terminal), and
// delegates the response to the streaming or non-streaming handler. The gin
// context receives the client-facing response as a side effect.
//
// Returns the forward result (usage, timing, upstream request id) on
// success, an *UpstreamFailoverError when the caller should retry the
// request on another account, or a plain error on terminal failure.
func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) {
	// Parse request body once (avoid multiple parse/serialize cycles)
	var reqBody map[string]any
	if err := json.Unmarshal(body, &reqBody); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}

	reqModel, _ := reqBody["model"].(string)
	reqStream, _ := reqBody["stream"].(bool)
	// Remember the client's model name so responses can be rewritten back.
	originalModel := reqModel

	bodyModified := false
	mappedModel := account.GetMappedModel(reqModel)
	if mappedModel != reqModel {
		log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
		reqBody["model"] = mappedModel
		bodyModified = true
	}

	// For streaming requests, opt into usage reporting unless the client
	// already expressed an explicit include_usage preference.
	if reqStream && includeUsage {
		streamOptions, _ := reqBody["stream_options"].(map[string]any)
		if streamOptions == nil {
			streamOptions = map[string]any{}
		}
		if _, ok := streamOptions["include_usage"]; !ok {
			streamOptions["include_usage"] = true
			reqBody["stream_options"] = streamOptions
			bodyModified = true
		}
	}

	// Re-serialize only when something actually changed.
	if bodyModified {
		var err error
		body, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("serialize request body: %w", err)
		}
	}

	// Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}

	upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token)
	if err != nil {
		return nil, err
	}

	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	// Record the (possibly rewritten) upstream body for ops diagnostics.
	// NOTE(review): c is nil-checked here but dereferenced unconditionally
	// below (c.JSON, handleErrorResponse) — confirm callers never pass nil.
	if c != nil {
		c.Set(OpsUpstreamRequestBodyKey, string(body))
	}

	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// Transport-level failure: record a sanitized ops event and answer
		// the client with a generic 502 (no upstream details leaked).
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream request failed",
			},
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode >= 400 {
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			// Capture up to 2 MiB of the error body, then rewind it so
			// later consumers (handleFailoverSideEffects) can re-read it.
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(respBody))

			upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
			upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
			upstreamDetail := ""
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048 // default detail cap when unconfigured
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})

			s.handleFailoverSideEffects(ctx, resp, account)
			// Signal the caller to retry this request on another account.
			return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
		}
		// Non-failover error: relay the upstream error to the client.
		return s.handleErrorResponse(ctx, resp, c, account, body)
	}

	var usage *OpenAIUsage
	var firstTokenMs *int
	if reqStream {
		streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
		usage = streamResult.usage
		firstTokenMs = streamResult.firstTokenMs
	} else {
		usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
	}

	// Normalize: callers always receive a non-nil usage value.
	if usage == nil {
		usage = &OpenAIUsage{}
	}

	return &OpenAIForwardResult{
		RequestID:    resp.Header.Get("x-request-id"),
		Usage:        *usage,
		Model:        originalModel, // report the client's model name, not the mapped one
		Stream:       reqStream,
		Duration:     time.Since(startTime),
		FirstTokenMs: firstTokenMs,
	}, nil
}
|
|
|
|
func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) {
|
|
var targetURL string
|
|
baseURL := account.GetOpenAIBaseURL()
|
|
if baseURL == "" {
|
|
targetURL = openaiChatAPIURL
|
|
} else {
|
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
targetURL = validatedURL + "/chat/completions"
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.Set("authorization", "Bearer "+token)
|
|
|
|
for key, values := range c.Request.Header {
|
|
lowerKey := strings.ToLower(key)
|
|
if openaiChatAllowedHeaders[lowerKey] {
|
|
for _, v := range values {
|
|
req.Header.Add(key, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
customUA := account.GetOpenAIUserAgent()
|
|
if customUA != "" {
|
|
req.Header.Set("user-agent", customUA)
|
|
}
|
|
|
|
if req.Header.Get("content-type") == "" {
|
|
req.Header.Set("content-type", "application/json")
|
|
}
|
|
|
|
return req, nil
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
|
|
if s.responseHeaderFilter != nil {
|
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
c.Header("X-Accel-Buffering", "no")
|
|
|
|
if v := resp.Header.Get("x-request-id"); v != "" {
|
|
c.Header("x-request-id", v)
|
|
}
|
|
|
|
w := c.Writer
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
return nil, errors.New("streaming not supported")
|
|
}
|
|
|
|
usage := &OpenAIUsage{}
|
|
var firstTokenMs *int
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
maxLineSize := defaultMaxLineSize
|
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|
}
|
|
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
|
|
|
type scanEvent struct {
|
|
line string
|
|
err error
|
|
}
|
|
events := make(chan scanEvent, 16)
|
|
done := make(chan struct{})
|
|
sendEvent := func(ev scanEvent) bool {
|
|
select {
|
|
case events <- ev:
|
|
return true
|
|
case <-done:
|
|
return false
|
|
}
|
|
}
|
|
var lastReadAt int64
|
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
|
go func() {
|
|
defer close(events)
|
|
for scanner.Scan() {
|
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
|
return
|
|
}
|
|
}
|
|
if err := scanner.Err(); err != nil {
|
|
_ = sendEvent(scanEvent{err: err})
|
|
}
|
|
}()
|
|
defer close(done)
|
|
|
|
streamInterval := time.Duration(0)
|
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
|
}
|
|
var intervalTicker *time.Ticker
|
|
if streamInterval > 0 {
|
|
intervalTicker = time.NewTicker(streamInterval)
|
|
defer intervalTicker.Stop()
|
|
}
|
|
var intervalCh <-chan time.Time
|
|
if intervalTicker != nil {
|
|
intervalCh = intervalTicker.C
|
|
}
|
|
|
|
keepaliveInterval := time.Duration(0)
|
|
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
|
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
|
}
|
|
var keepaliveTicker *time.Ticker
|
|
if keepaliveInterval > 0 {
|
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
|
defer keepaliveTicker.Stop()
|
|
}
|
|
var keepaliveCh <-chan time.Time
|
|
if keepaliveTicker != nil {
|
|
keepaliveCh = keepaliveTicker.C
|
|
}
|
|
lastDataAt := time.Now()
|
|
|
|
errorEventSent := false
|
|
sendErrorEvent := func(reason string) {
|
|
if errorEventSent {
|
|
return
|
|
}
|
|
errorEventSent = true
|
|
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
|
flusher.Flush()
|
|
}
|
|
|
|
needModelReplace := originalModel != mappedModel
|
|
|
|
for {
|
|
select {
|
|
case ev, ok := <-events:
|
|
if !ok {
|
|
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
|
}
|
|
if ev.err != nil {
|
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
|
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
|
sendErrorEvent("response_too_large")
|
|
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
|
}
|
|
sendErrorEvent("stream_read_error")
|
|
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
|
}
|
|
|
|
line := ev.line
|
|
lastDataAt = time.Now()
|
|
|
|
if openaiSSEDataRe.MatchString(line) {
|
|
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
|
|
|
if needModelReplace {
|
|
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
|
}
|
|
|
|
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
|
line = "data: " + correctedData
|
|
}
|
|
|
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
|
sendErrorEvent("write_failed")
|
|
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
}
|
|
flusher.Flush()
|
|
|
|
if firstTokenMs == nil {
|
|
if event := parseChatStreamEvent(data); event != nil {
|
|
if chatChunkHasDelta(event) {
|
|
ms := int(time.Since(startTime).Milliseconds())
|
|
firstTokenMs = &ms
|
|
}
|
|
applyChatUsageFromEvent(event, usage)
|
|
}
|
|
} else {
|
|
if event := parseChatStreamEvent(data); event != nil {
|
|
applyChatUsageFromEvent(event, usage)
|
|
}
|
|
}
|
|
} else {
|
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
|
sendErrorEvent("write_failed")
|
|
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
|
|
case <-intervalCh:
|
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
|
if time.Since(lastRead) < streamInterval {
|
|
continue
|
|
}
|
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
|
if s.rateLimitService != nil {
|
|
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
|
}
|
|
sendErrorEvent("stream_timeout")
|
|
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
|
|
|
case <-keepaliveCh:
|
|
if time.Since(lastDataAt) < keepaliveInterval {
|
|
continue
|
|
}
|
|
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
|
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
usage := &OpenAIUsage{}
|
|
var parsed map[string]any
|
|
if json.Unmarshal(body, &parsed) == nil {
|
|
if usageMap, ok := parsed["usage"].(map[string]any); ok {
|
|
applyChatUsageFromMap(usageMap, usage)
|
|
}
|
|
}
|
|
|
|
if originalModel != mappedModel {
|
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
|
}
|
|
body = s.correctToolCallsInResponseBody(body)
|
|
|
|
if s.responseHeaderFilter != nil {
|
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
|
}
|
|
|
|
contentType := "application/json"
|
|
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
|
|
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
|
|
contentType = upstreamType
|
|
}
|
|
}
|
|
|
|
c.Data(resp.StatusCode, contentType, body)
|
|
return usage, nil
|
|
}
|
|
|
|
// parseChatStreamEvent decodes one SSE data payload into a generic map.
// The empty payload, the terminal "[DONE]" marker, and malformed JSON all
// yield nil.
func parseChatStreamEvent(data string) map[string]any {
	switch data {
	case "", "[DONE]":
		return nil
	}
	var event map[string]any
	if err := json.Unmarshal([]byte(data), &event); err != nil {
		return nil
	}
	return event
}
|
|
|
|
// chatChunkHasDelta reports whether a streamed chunk carries real delta
// content in any choice: non-blank text, at least one tool call, or a
// non-empty legacy function_call.
func chatChunkHasDelta(event map[string]any) bool {
	choices, _ := event["choices"].([]any)
	for _, raw := range choices {
		choice, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		delta, ok := choice["delta"].(map[string]any)
		if !ok {
			continue
		}
		if text, ok := delta["content"].(string); ok && strings.TrimSpace(text) != "" {
			return true
		}
		if calls, ok := delta["tool_calls"].([]any); ok && len(calls) > 0 {
			return true
		}
		if fn, ok := delta["function_call"].(map[string]any); ok && len(fn) > 0 {
			return true
		}
	}
	return false
}
|
|
|
|
func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) {
|
|
if event == nil || usage == nil {
|
|
return
|
|
}
|
|
usageMap, ok := event["usage"].(map[string]any)
|
|
if !ok {
|
|
return
|
|
}
|
|
applyChatUsageFromMap(usageMap, usage)
|
|
}
|
|
|
|
func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) {
|
|
if usageMap == nil || usage == nil {
|
|
return
|
|
}
|
|
promptTokens := int(getNumber(usageMap["prompt_tokens"]))
|
|
completionTokens := int(getNumber(usageMap["completion_tokens"]))
|
|
if promptTokens > 0 {
|
|
usage.InputTokens = promptTokens
|
|
}
|
|
if completionTokens > 0 {
|
|
usage.OutputTokens = completionTokens
|
|
}
|
|
}
|