mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-08 01:00:21 +08:00
150 lines
4.6 KiB
Go
150 lines
4.6 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"bytes"
|
|||
|
|
"context"
|
|||
|
|
"fmt"
|
|||
|
|
"io"
|
|||
|
|
"net/http"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|||
|
|
"github.com/gin-gonic/gin"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
|
|||
|
|
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions",
|
|||
|
|
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
|
|||
|
|
// 支持流式和非流式响应的直接透传。
|
|||
|
|
func (s *SoraGatewayService) forwardToUpstream(
|
|||
|
|
ctx context.Context,
|
|||
|
|
c *gin.Context,
|
|||
|
|
account *Account,
|
|||
|
|
body []byte,
|
|||
|
|
clientStream bool,
|
|||
|
|
startTime time.Time,
|
|||
|
|
) (*ForwardResult, error) {
|
|||
|
|
apiKey := account.GetCredential("api_key")
|
|||
|
|
if apiKey == "" {
|
|||
|
|
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
|
|||
|
|
return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
baseURL := account.GetBaseURL()
|
|||
|
|
if baseURL == "" {
|
|||
|
|
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
|
|||
|
|
return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
|
|||
|
|
}
|
|||
|
|
// 校验 scheme 合法性(仅允许 http/https)
|
|||
|
|
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
|||
|
|
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
|
|||
|
|
return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
|
|||
|
|
}
|
|||
|
|
upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
|
|||
|
|
|
|||
|
|
// 构建上游请求
|
|||
|
|
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
|||
|
|
if err != nil {
|
|||
|
|
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
|
|||
|
|
return nil, fmt.Errorf("create upstream request: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
upstreamReq.Header.Set("Content-Type", "application/json")
|
|||
|
|
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
|||
|
|
|
|||
|
|
// 透传客户端的部分请求头
|
|||
|
|
for _, header := range []string{"Accept", "Accept-Encoding"} {
|
|||
|
|
if v := c.GetHeader(header); v != "" {
|
|||
|
|
upstreamReq.Header.Set(header, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
|
|||
|
|
|
|||
|
|
// 获取代理 URL
|
|||
|
|
proxyURL := ""
|
|||
|
|
if account.ProxyID != nil && account.Proxy != nil {
|
|||
|
|
proxyURL = account.Proxy.URL()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 发送请求
|
|||
|
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
|||
|
|
if err != nil {
|
|||
|
|
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
|
|||
|
|
return nil, &UpstreamFailoverError{
|
|||
|
|
StatusCode: http.StatusBadGateway,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
defer func() {
|
|||
|
|
_ = resp.Body.Close()
|
|||
|
|
}()
|
|||
|
|
|
|||
|
|
// 错误响应处理
|
|||
|
|
if resp.StatusCode >= 400 {
|
|||
|
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
|||
|
|
|
|||
|
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
|||
|
|
return nil, &UpstreamFailoverError{
|
|||
|
|
StatusCode: resp.StatusCode,
|
|||
|
|
ResponseBody: respBody,
|
|||
|
|
ResponseHeaders: resp.Header.Clone(),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 非转移错误,直接透传给客户端
|
|||
|
|
c.Status(resp.StatusCode)
|
|||
|
|
for key, values := range resp.Header {
|
|||
|
|
for _, v := range values {
|
|||
|
|
c.Writer.Header().Add(key, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if _, err := c.Writer.Write(respBody); err != nil {
|
|||
|
|
return nil, fmt.Errorf("write upstream error response: %w", err)
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 成功响应 — 直接透传
|
|||
|
|
c.Status(resp.StatusCode)
|
|||
|
|
for key, values := range resp.Header {
|
|||
|
|
lower := strings.ToLower(key)
|
|||
|
|
// 透传内容相关头部
|
|||
|
|
if lower == "content-type" || lower == "transfer-encoding" ||
|
|||
|
|
lower == "cache-control" || lower == "x-request-id" {
|
|||
|
|
for _, v := range values {
|
|||
|
|
c.Writer.Header().Add(key, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 流式复制响应体
|
|||
|
|
if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
|
|||
|
|
buf := make([]byte, 4096)
|
|||
|
|
for {
|
|||
|
|
n, readErr := resp.Body.Read(buf)
|
|||
|
|
if n > 0 {
|
|||
|
|
if _, err := c.Writer.Write(buf[:n]); err != nil {
|
|||
|
|
return nil, fmt.Errorf("stream upstream response write: %w", err)
|
|||
|
|
}
|
|||
|
|
flusher.Flush()
|
|||
|
|
}
|
|||
|
|
if readErr != nil {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
|||
|
|
return nil, fmt.Errorf("copy upstream response: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
duration := time.Since(startTime)
|
|||
|
|
return &ForwardResult{
|
|||
|
|
RequestID: resp.Header.Get("x-request-id"),
|
|||
|
|
Model: "", // 由调用方填充
|
|||
|
|
Stream: clientStream,
|
|||
|
|
Duration: duration,
|
|||
|
|
}, nil
|
|||
|
|
}
|