OpenAI APIKey accounts with base_url pointing to third-party OpenAI-compatible
upstreams (DeepSeek, Kimi, GLM, Qwen, etc.) were failing because the gateway
unconditionally converted Chat Completions requests to Responses format and
forwarded them to {base_url}/v1/responses, an endpoint that exists only on
OpenAI's official API.
Detection-based routing:
- Probe upstream capability on account create/update via a minimal POST to
  /v1/responses; HTTP 404/405 means 'unsupported', any other response means
  'supported' (see the sketch after this list).
- Persist the result as accounts.extra.openai_responses_supported (bool).
- ForwardAsChatCompletions branches at function entry: APIKey accounts with
  an explicit support=false go through the new forwardAsRawChatCompletions,
  which forwards the CC body to /v1/chat/completions as-is, with no protocol
  conversion.
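
A minimal sketch of that probe, assuming illustrative names and signature
(the real logic lives in openai_apikey_responses_probe.go):

    import (
        "context"
        "net/http"
        "strings"
    )

    // probeResponsesSupport reports whether {base_url}/v1/responses exists.
    // Only 404/405 count as 'unsupported': any other status (400, 401, 429...)
    // proves the route is present even if the probe body itself is rejected.
    func probeResponsesSupport(ctx context.Context, client *http.Client, baseURL, apiKey string) (bool, error) {
        body := strings.NewReader(`{"model":"probe","input":"ping","max_output_tokens":16}`)
        req, err := http.NewRequestWithContext(ctx, http.MethodPost,
            strings.TrimRight(baseURL, "/")+"/v1/responses", body)
        if err != nil {
            return false, err
        }
        req.Header.Set("Authorization", "Bearer "+apiKey)
        req.Header.Set("Content-Type", "application/json")
        resp, err := client.Do(req)
        if err != nil {
            return false, err // network failure: leave the marker unset
        }
        defer resp.Body.Close()
        return resp.StatusCode != http.StatusNotFound &&
            resp.StatusCode != http.StatusMethodNotAllowed, nil
    }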
Default behavior for accounts without the marker preserves the legacy
'always Responses' path: existing OpenAI APIKey accounts that worked before
this change keep working without modification (the 'reality is evidence'
principle: an account that has been running successfully implies the upstream
supports Responses). The tri-state logic is sketched below.
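
The routing helper reduces to a few lines; this sketch assumes the account's
Extra blob is exposed as a map[string]any (the real helper is
openai_compat.ShouldUseResponsesAPI):

    // ShouldUseResponsesAPI returns false only when a probe has explicitly
    // recorded that the upstream lacks /v1/responses. Missing or malformed
    // markers fail open to the legacy always-Responses path.
    func ShouldUseResponsesAPI(extra map[string]any) bool {
        v, ok := extra["openai_responses_supported"]
        if !ok {
            return true // never probed: reality is evidence
        }
        supported, isBool := v.(bool)
        return !isBool || supported
    }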
The probe fires asynchronously after Create / Update / BatchCreate (see the
sketch below); failures are only logged and never block the admin flow.
BulkUpdate is omitted (base_url changes are rare there; support can be added
if needed).
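
Roughly, the fire-and-forget dispatch looks like this (method and helper names
here are illustrative, not the actual service API; logger/zap usage matches the
gateway code):

    // scheduleResponsesProbe detaches from the request context so the admin
    // call returns immediately; probe errors are logged, never surfaced.
    func (s *AccountService) scheduleResponsesProbe(account *Account) {
        go func() {
            ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
            defer cancel()
            if err := s.probeAndPersist(ctx, account); err != nil {
                logger.L().Warn("openai responses probe failed",
                    zap.Int64("account_id", account.ID), zap.Error(err))
            }
        }()
    }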
Implementation:
- New pkg internal/pkg/openai_compat: marker key + ShouldUseResponsesAPI
- New service file openai_apikey_responses_probe.go: probe + persist
- New service file openai_gateway_chat_completions_raw.go: CC pass-through
- Account test endpoint short-circuits with an explicit message for
  probed-unsupported accounts (a full CC test path remains a TODO)
Zero schema changes, zero migrations, zero frontend changes, zero wire
modifications; everything is wired through the existing AccountTestService
injection.
Closes: DeepSeek-OpenAI account (id=128) production failure
package service

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
	"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
	"go.uber.org/zap"
)

// cursorResponsesUnsupportedFields are top-level Responses API parameters that
// Codex upstreams reject with "Unsupported parameter: ...". They must be
// stripped when forwarding a raw client body through the Responses-shape
// short-circuit in ForwardAsChatCompletions (see isResponsesShape branch).
// The normal Chat Completions → Responses conversion path is unaffected
// because ChatCompletionsRequest has no fields for these parameters — unknown
// fields are dropped naturally by json.Unmarshal. Kept semantically in sync
// with the list in openai_gateway_service.go:2034 used by the /v1/responses
// passthrough path.
var cursorResponsesUnsupportedFields = []string{
	"prompt_cache_retention",
	"safety_identifier",
	"metadata",
	"stream_options",
}

// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards it to the OpenAI upstream, and
// converts the response back to Chat Completions format.
//
// History: this function originally ran the CC→Responses conversion and hit
// /v1/responses for every OpenAI account without distinction. That is correct
// for OAuth (the ChatGPT internal API only supports Responses) and for official
// APIKey accounts, but the assumption broke once sub2api integrated
// DeepSeek/Kimi/GLM and other third-party OpenAI-compatible upstreams: those
// upstreams generally support only /v1/chat/completions and have no
// /v1/responses endpoint.
//
// Current routing policy (driven by the per-account probe marker; see
// openai_compat.ShouldUseResponsesAPI):
//   - APIKey account + probe confirmed Responses unsupported → go through
//     forwardAsRawChatCompletions straight to the upstream
//     /v1/chat/completions, with no protocol conversion
//   - everything else (OAuth, APIKey confirmed supported, never probed) → the
//     original CC→Responses conversion path (legacy behavior preserved; zero
//     compatibility breakage for existing unprobed accounts)
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	promptCacheKey string,
	defaultMappedModel string,
) (*OpenAIForwardResult, error) {
	// Entry routing: an APIKey account whose probe confirmed the upstream does
	// not support Responses goes straight through as Chat Completions. A
	// missing marker (never probed) falls through to the original Responses
	// conversion path below, per the "reality is evidence" principle.
	if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
		return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
	}

	startTime := time.Now()

	// 1. Parse Chat Completions request
	var chatReq apicompat.ChatCompletionsRequest
	if err := json.Unmarshal(body, &chatReq); err != nil {
		return nil, fmt.Errorf("parse chat completions request: %w", err)
	}
	originalModel := chatReq.Model
	clientStream := chatReq.Stream
	includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage

	// 2. Resolve model mapping early so compat prompt_cache_key injection can
	// derive a stable seed from the final upstream model family.
	billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
	upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)

	promptCacheKey = strings.TrimSpace(promptCacheKey)
	compatPromptCacheInjected := false
	if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
		promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel)
		compatPromptCacheInjected = promptCacheKey != ""
	}

	// 3. Build the upstream (Responses API) body.
	//
	// Cursor compatibility: some clients (notably Cursor cloud) send Responses
	// API shaped bodies — `input: [...]` with no `messages` field — to the
	// /v1/chat/completions URL. Running those through ChatCompletionsToResponses
	// would silently drop Cursor's `input` array (the struct has no Input field)
	// and produce `input: null`, which Codex upstreams reject with
	// "Invalid type for 'input': expected a string, but got an object".
	//
	// Detect that shape and forward the raw body as-is, only rewriting `model`
	// to the resolved upstream model. The downstream codex OAuth transform will
	// still normalize store/stream/instructions/etc.
	isResponsesShape := !gjson.GetBytes(body, "messages").Exists() && gjson.GetBytes(body, "input").Exists()

	var (
		responsesReq  *apicompat.ResponsesRequest
		responsesBody []byte
		err           error
	)
	if isResponsesShape {
		responsesBody, err = sjson.SetBytes(body, "model", upstreamModel)
		if err != nil {
			return nil, fmt.Errorf("rewrite model in responses-shape body: %w", err)
		}
		// Strip Responses API parameters that no Codex upstream accepts.
		// Because this branch forwards the raw body (the normal path rebuilds
		// it from ChatCompletionsRequest and drops unknown fields naturally),
		// we must filter these fields explicitly here — otherwise the upstream
		// rejects the request with "Unsupported parameter: ...".
		for _, field := range cursorResponsesUnsupportedFields {
			if stripped, derr := sjson.DeleteBytes(responsesBody, field); derr == nil {
				responsesBody = stripped
			}
		}
		// Plain assignment on purpose: `:=` here would shadow responsesBody and
		// err inside this block, and the service_tier-normalized body would be
		// silently discarded once the block ends.
		var normalizedServiceTier string
		responsesBody, normalizedServiceTier, err = normalizeResponsesBodyServiceTier(responsesBody)
		if err != nil {
			return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
		}
		// Minimal stub populated from the raw body so downstream billing
		// propagation (ServiceTier, ReasoningEffort) keeps working.
		responsesReq = &apicompat.ResponsesRequest{
			Model:       upstreamModel,
			ServiceTier: normalizedServiceTier,
		}
		if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
			responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
		}
	} else {
		// Normal path: convert Chat Completions → Responses.
		// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
		responsesReq, err = apicompat.ChatCompletionsToResponses(&chatReq)
		if err != nil {
			return nil, fmt.Errorf("convert chat completions to responses: %w", err)
		}
		responsesReq.Model = upstreamModel
		normalizeResponsesRequestServiceTier(responsesReq)
		responsesBody, err = json.Marshal(responsesReq)
		if err != nil {
			return nil, fmt.Errorf("marshal responses request: %w", err)
		}
	}

	logFields := []zap.Field{
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("billing_model", billingModel),
		zap.String("upstream_model", upstreamModel),
		zap.Bool("stream", clientStream),
		zap.Bool("responses_shape", isResponsesShape),
	}
	if compatPromptCacheInjected {
		logFields = append(logFields,
			zap.Bool("compat_prompt_cache_key_injected", true),
			zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
		)
	}
	logger.L().Debug("openai chat_completions: model mapping applied", logFields...)

	if account.Type == AccountTypeOAuth {
		var reqBody map[string]any
		if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
			return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
		}
		codexResult := applyCodexOAuthTransform(reqBody, false, false)
		if codexResult.NormalizedModel != "" {
			upstreamModel = codexResult.NormalizedModel
		}
		if codexResult.PromptCacheKey != "" {
			promptCacheKey = codexResult.PromptCacheKey
		} else if promptCacheKey != "" {
			reqBody["prompt_cache_key"] = promptCacheKey
		}
		responsesBody, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("remarshal after codex transform: %w", err)
		}
	}

	// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
	updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
	if policyErr != nil {
		var blocked *OpenAIFastBlockedError
		if errors.As(policyErr, &blocked) {
			writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
		}
		return nil, policyErr
	}
	responsesBody = updatedBody

	// 5. Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}

	// 6. Build upstream request
	upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}

	if promptCacheKey != "" {
		upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
	}

	// 7. Send request
	proxyURL := ""
	if account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()

	// 8. Handle error response with failover
	if resp.StatusCode >= 400 {
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))

		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			upstreamDetail := ""
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
			}
		}
		return s.handleChatCompletionsErrorResponse(resp, c, account)
	}

	// 9. Handle normal response
	var result *OpenAIForwardResult
	var handleErr error
	if clientStream {
		result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime)
	} else {
		result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
	}

	// Propagate ServiceTier and ReasoningEffort to result for billing
	if handleErr == nil && result != nil {
		if responsesReq.ServiceTier != "" {
			st := responsesReq.ServiceTier
			result.ServiceTier = &st
		}
		if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
			re := responsesReq.Reasoning.Effort
			result.ReasoningEffort = &re
		}
	}

	// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
	if handleErr == nil && account.Type == AccountTypeOAuth {
		if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
			s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
		}
	}

	return result, handleErr
}

func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
	if req == nil {
		return
	}
	req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
}

func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
	if len(body) == 0 {
		return body, "", nil
	}
	rawServiceTier := gjson.GetBytes(body, "service_tier").String()
	if rawServiceTier == "" {
		return body, "", nil
	}
	normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
	if normalizedServiceTier == "" {
		trimmed, err := sjson.DeleteBytes(body, "service_tier")
		return trimmed, "", err
	}
	if normalizedServiceTier == rawServiceTier {
		return body, normalizedServiceTier, nil
	}
	trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
	return trimmed, normalizedServiceTier, err
}

func normalizedOpenAIServiceTierValue(raw string) string {
	normalized := normalizeOpenAIServiceTier(raw)
	if normalized == nil {
		return ""
	}
	return *normalized
}

// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
	resp *http.Response,
	c *gin.Context,
	account *Account,
) (*OpenAIForwardResult, error) {
	return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}

// handleChatBufferedStreamingResponse reads all Responses SSE events from the
// upstream, finds the terminal event, converts to a Chat Completions JSON
// response, and writes it to the client.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	billingModel string,
	upstreamModel string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var finalResponse *apicompat.ResponsesResponse
	var usage OpenAIUsage
	acc := apicompat.NewBufferedResponseAccumulator()

	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
			continue
		}
		payload := line[6:]

		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("openai chat_completions buffered: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			continue
		}

		// Accumulate delta content for fallback when terminal output is empty.
		acc.ProcessEvent(&event)

		if (event.Type == "response.completed" || event.Type == "response.done" ||
			event.Type == "response.incomplete" || event.Type == "response.failed") &&
			event.Response != nil {
			finalResponse = event.Response
			if event.Response.Usage != nil {
				usage = OpenAIUsage{
					InputTokens:  event.Response.Usage.InputTokens,
					OutputTokens: event.Response.Usage.OutputTokens,
				}
				if event.Response.Usage.InputTokensDetails != nil {
					usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("openai chat_completions buffered: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	if finalResponse == nil {
		writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
		return nil, fmt.Errorf("upstream stream ended without terminal event")
	}

	// When the terminal event has an empty output array, reconstruct from
	// accumulated delta events so the client receives the full content.
	acc.SupplementResponseOutput(finalResponse)

	chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel)

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, chatResp)

	return &OpenAIForwardResult{
		RequestID:     requestID,
		Usage:         usage,
		Model:         originalModel,
		BillingModel:  billingModel,
		UpstreamModel: upstreamModel,
		Stream:        false,
		Duration:      time.Since(startTime),
	}, nil
}

// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	billingModel string,
	upstreamModel string,
	includeUsage bool,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)

	state := apicompat.NewResponsesEventToChatState()
	state.Model = originalModel
	state.IncludeUsage = includeUsage

	var usage OpenAIUsage
	var firstTokenMs *int
	firstChunk := true

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	resultWithUsage := func() *OpenAIForwardResult {
		return &OpenAIForwardResult{
			RequestID:     requestID,
			Usage:         usage,
			Model:         originalModel,
			BillingModel:  billingModel,
			UpstreamModel: upstreamModel,
			Stream:        true,
			Duration:      time.Since(startTime),
			FirstTokenMs:  firstTokenMs,
		}
	}

	processDataLine := func(payload string) bool {
		if firstChunk {
			firstChunk = false
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}

		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("openai chat_completions stream: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			return false
		}

		// Extract usage from completion events
		if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
			event.Response != nil && event.Response.Usage != nil {
			usage = OpenAIUsage{
				InputTokens:  event.Response.Usage.InputTokens,
				OutputTokens: event.Response.Usage.OutputTokens,
			}
			if event.Response.Usage.InputTokensDetails != nil {
				usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
			}
		}

		chunks := apicompat.ResponsesEventToChatChunks(&event, state)
		for _, chunk := range chunks {
			sse, err := apicompat.ChatChunkToSSE(chunk)
			if err != nil {
				logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
					zap.Error(err),
					zap.String("request_id", requestID),
				)
				continue
			}
			if _, err := fmt.Fprint(c.Writer, sse); err != nil {
				logger.L().Info("openai chat_completions stream: client disconnected",
					zap.String("request_id", requestID),
				)
				return true
			}
		}
		if len(chunks) > 0 {
			c.Writer.Flush()
		}
		return false
	}

	finalizeStream := func() (*OpenAIForwardResult, error) {
		if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
			for _, chunk := range finalChunks {
				sse, err := apicompat.ChatChunkToSSE(chunk)
				if err != nil {
					continue
				}
				fmt.Fprint(c.Writer, sse) //nolint:errcheck
			}
		}
		// Send [DONE] sentinel
		fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
		c.Writer.Flush()
		return resultWithUsage(), nil
	}

	handleScanErr := func(err error) {
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("openai chat_completions stream: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	// Determine keepalive interval
	keepaliveInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
	}

	// No keepalive: fast synchronous path
	if keepaliveInterval <= 0 {
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
				continue
			}
			if processDataLine(line[6:]) {
				return resultWithUsage(), nil
			}
		}
		handleScanErr(scanner.Err())
		return finalizeStream()
	}

	// With keepalive: goroutine + channel + select
	type scanEvent struct {
		line string
		err  error
	}
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	go func() {
		defer close(events)
		for scanner.Scan() {
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)

	keepaliveTicker := time.NewTicker(keepaliveInterval)
	defer keepaliveTicker.Stop()
	lastDataAt := time.Now()

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				return finalizeStream()
			}
			if ev.err != nil {
				handleScanErr(ev.err)
				return finalizeStream()
			}
			lastDataAt = time.Now()
			line := ev.line
			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
				continue
			}
			if processDataLine(line[6:]) {
				return resultWithUsage(), nil
			}

		case <-keepaliveTicker.C:
			if time.Since(lastDataAt) < keepaliveInterval {
				continue
			}
			// Send SSE comment as keepalive
			if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
				logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
					zap.String("request_id", requestID),
				)
				return resultWithUsage(), nil
			}
			c.Writer.Flush()
		}
	}
}

// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
	c.JSON(statusCode, gin.H{
		"error": gin.H{
			"type":    errType,
			"message": message,
		},
	})
}