Files
sub2api/backend/internal/service/openai_gateway_chat_completions.go
2026-05-03 17:11:27 +08:00

679 lines
22 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// cursorResponsesUnsupportedFields lists the top-level Responses API
// parameters that Codex upstreams reject with "Unsupported parameter: ...".
//
// ForwardAsChatCompletions deletes each of these keys from the raw client body
// when it takes the Responses-shape short-circuit (the isResponsesShape
// branch), because that branch forwards the client's JSON instead of
// rebuilding it. The regular Chat Completions → Responses conversion path does
// not need this list: ChatCompletionsRequest has no fields for these
// parameters, so json.Unmarshal drops them on its own.
//
// Keep this list semantically in sync with the equivalent list used by the
// /v1/responses passthrough path in openai_gateway_service.go.
// NOTE(review): the original comment pointed at line 2034 of that file; that
// location is not visible from here and line numbers drift — confirm.
var cursorResponsesUnsupportedFields = []string{
"prompt_cache_retention",
"safety_identifier",
"metadata",
"stream_options",
}
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Chat Completions format. All account types (OAuth and API
// Key) go through the Responses API conversion path since the upstream only
// exposes the /v1/responses endpoint.
//
// Parameters:
//   - ctx: request context; the upstream request is built on a context derived
//     from it via detachUpstreamContext.
//   - c: gin context used to write the client-facing response or error.
//   - account: upstream account used for model mapping, auth and proxy settings.
//   - body: raw client request body (Chat Completions shaped, or Responses
//     shaped — see step 3).
//   - promptCacheKey: caller-supplied prompt cache key; may be empty.
//   - defaultMappedModel: fallback passed to resolveOpenAIForwardModel.
//
// Returns the forward result produced by the streaming or buffered handler.
// When the upstream answers with an error status that qualifies for failover,
// an *UpstreamFailoverError is returned and nothing is written to the client.
// A fast-policy block and a transport failure write a Chat Completions error to
// c before returning; a request parse failure only returns the error.
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	promptCacheKey string,
	defaultMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()

	// 1. Parse Chat Completions request
	var chatReq apicompat.ChatCompletionsRequest
	if err := json.Unmarshal(body, &chatReq); err != nil {
		return nil, fmt.Errorf("parse chat completions request: %w", err)
	}
	originalModel := chatReq.Model
	clientStream := chatReq.Stream
	includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage

	// 2. Resolve model mapping early so compat prompt_cache_key injection can
	// derive a stable seed from the final upstream model family.
	billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
	upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)

	promptCacheKey = strings.TrimSpace(promptCacheKey)
	compatPromptCacheInjected := false
	if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
		promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel)
		compatPromptCacheInjected = promptCacheKey != ""
	}

	// 3. Build the upstream (Responses API) body.
	//
	// Cursor compatibility: some clients (notably Cursor cloud) send Responses
	// API shaped bodies — `input: [...]` with no `messages` field — to the
	// /v1/chat/completions URL. Running those through ChatCompletionsToResponses
	// would silently drop Cursor's `input` array (the struct has no Input field)
	// and produce `input: null`, which Codex upstreams reject with
	// "Invalid type for 'input': expected a string, but got an object".
	//
	// Detect that shape and forward the raw body as-is, only rewriting `model`
	// to the resolved upstream model. The downstream codex OAuth transform will
	// still normalize store/stream/instructions/etc.
	isResponsesShape := !gjson.GetBytes(body, "messages").Exists() && gjson.GetBytes(body, "input").Exists()

	var (
		responsesReq  *apicompat.ResponsesRequest
		responsesBody []byte
		err           error
	)
	if isResponsesShape {
		responsesBody, err = sjson.SetBytes(body, "model", upstreamModel)
		if err != nil {
			return nil, fmt.Errorf("rewrite model in responses-shape body: %w", err)
		}

		// Strip Responses API parameters that no Codex upstream accepts.
		// Because this branch forwards the raw body (the normal path rebuilds
		// it from ChatCompletionsRequest and drops unknown fields naturally),
		// we must filter these fields explicitly here — otherwise the upstream
		// rejects the request with "Unsupported parameter: ...".
		// Deletion is best-effort: on error the body is left as it was.
		for _, field := range cursorResponsesUnsupportedFields {
			if stripped, derr := sjson.DeleteBytes(responsesBody, field); derr == nil {
				responsesBody = stripped
			}
		}

		// Assign with "=" rather than ":=": a short variable declaration here
		// would create block-local responsesBody/err that shadow the
		// function-scope variables, and the normalized body would be thrown
		// away when this block ends instead of being sent upstream.
		var normalizedServiceTier string
		responsesBody, normalizedServiceTier, err = normalizeResponsesBodyServiceTier(responsesBody)
		if err != nil {
			return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
		}

		// Minimal stub populated from the raw body so downstream billing
		// propagation (ServiceTier, ReasoningEffort) keeps working.
		responsesReq = &apicompat.ResponsesRequest{
			Model:       upstreamModel,
			ServiceTier: normalizedServiceTier,
		}
		if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
			responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
		}
	} else {
		// Normal path: convert Chat Completions → Responses.
		// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
		responsesReq, err = apicompat.ChatCompletionsToResponses(&chatReq)
		if err != nil {
			return nil, fmt.Errorf("convert chat completions to responses: %w", err)
		}
		responsesReq.Model = upstreamModel
		normalizeResponsesRequestServiceTier(responsesReq)
		responsesBody, err = json.Marshal(responsesReq)
		if err != nil {
			return nil, fmt.Errorf("marshal responses request: %w", err)
		}
	}

	logFields := []zap.Field{
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("billing_model", billingModel),
		zap.String("upstream_model", upstreamModel),
		zap.Bool("stream", clientStream),
		zap.Bool("responses_shape", isResponsesShape),
	}
	if compatPromptCacheInjected {
		// Only a hash of the derived key is logged, never the key itself.
		logFields = append(logFields,
			zap.Bool("compat_prompt_cache_key_injected", true),
			zap.String("compat_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
		)
	}
	logger.L().Debug("openai chat_completions: model mapping applied", logFields...)

	// 4. Apply the Codex OAuth transform (OAuth accounts only). The transform
	// works on a generic map, so the body is decoded and re-encoded around it.
	if account.Type == AccountTypeOAuth {
		var reqBody map[string]any
		if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
			return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
		}
		codexResult := applyCodexOAuthTransform(reqBody, false, false)
		if codexResult.NormalizedModel != "" {
			upstreamModel = codexResult.NormalizedModel
		}
		// A key reported by the transform wins; otherwise inject ours.
		if codexResult.PromptCacheKey != "" {
			promptCacheKey = codexResult.PromptCacheKey
		} else if promptCacheKey != "" {
			reqBody["prompt_cache_key"] = promptCacheKey
		}
		responsesBody, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("remarshal after codex transform: %w", err)
		}
	}

	// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
	updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
	if policyErr != nil {
		var blocked *OpenAIFastBlockedError
		if errors.As(policyErr, &blocked) {
			writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
		}
		return nil, policyErr
	}
	responsesBody = updatedBody

	// 5. Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}

	// 6. Build upstream request
	upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
	upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
	releaseUpstreamCtx()
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}
	if promptCacheKey != "" {
		upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
	}

	// 7. Send request
	proxyURL := ""
	if account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// The client only sees a generic message; the sanitized detail is
		// recorded for ops and returned to the caller.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()

	// 8. Handle error response with failover
	if resp.StatusCode >= 400 {
		// Read at most 2 MiB of the error body, then swap in an in-memory
		// reader so later consumers can read the same bytes again.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))

		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			upstreamDetail := ""
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
			}
		}
		return s.handleChatCompletionsErrorResponse(resp, c, account)
	}

	// 9. Handle normal response. The upstream always streams; the client's
	// `stream` flag decides whether chunks are relayed or buffered into one
	// JSON response.
	var result *OpenAIForwardResult
	var handleErr error
	if clientStream {
		result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime)
	} else {
		result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
	}

	// Propagate ServiceTier and ReasoningEffort to result for billing
	if handleErr == nil && result != nil {
		if responsesReq.ServiceTier != "" {
			st := responsesReq.ServiceTier
			result.ServiceTier = &st
		}
		if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
			re := responsesReq.Reasoning.Effort
			result.ReasoningEffort = &re
		}
	}

	// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
	if handleErr == nil && account.Type == AccountTypeOAuth {
		if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
			s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
		}
	}
	return result, handleErr
}
// normalizeResponsesRequestServiceTier replaces req.ServiceTier, in place, with
// the value returned by normalizedOpenAIServiceTierValue. A nil req is a no-op.
func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
	if req != nil {
		req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
	}
}
// normalizeResponsesBodyServiceTier normalizes the top-level "service_tier"
// field of a raw JSON request body.
//
// It returns the (possibly rewritten) body, the normalized tier, and any error
// from the JSON edit:
//   - empty body, or no/empty "service_tier": body unchanged, tier "".
//   - tier normalizes to "": the field is deleted, tier "".
//   - tier already normalized: body unchanged, tier returned.
//   - otherwise: the field is overwritten with the normalized tier.
func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
	if len(body) == 0 {
		return body, "", nil
	}

	current := gjson.GetBytes(body, "service_tier").String()
	if current == "" {
		return body, "", nil
	}

	switch tier := normalizedOpenAIServiceTierValue(current); tier {
	case "":
		// Normalization yielded nothing usable: drop the field entirely.
		out, err := sjson.DeleteBytes(body, "service_tier")
		return out, "", err
	case current:
		// Already in normalized form: nothing to rewrite.
		return body, tier, nil
	default:
		out, err := sjson.SetBytes(body, "service_tier", tier)
		return out, tier, err
	}
}
// normalizedOpenAIServiceTierValue returns the string that
// normalizeOpenAIServiceTier points to for raw, or "" when it returns nil.
func normalizedOpenAIServiceTierValue(raw string) string {
	if tier := normalizeOpenAIServiceTier(raw); tier != nil {
		return *tier
	}
	return ""
}
// handleChatCompletionsErrorResponse handles an upstream error response for the
// Chat Completions endpoint. It is a thin adapter: all work is delegated to
// handleCompatErrorResponse, with writeChatCompletionsError supplied as the
// error writer so the client receives the OpenAI Chat Completions error shape.
//
// The return values are whatever handleCompatErrorResponse returns.
// NOTE(review): how the upstream body is read and mapped to a status/message is
// decided inside handleCompatErrorResponse, which is not visible here.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}
// handleChatBufferedStreamingResponse serves a non-streaming client from a
// streaming upstream: it consumes the upstream Responses SSE stream up to its
// terminal event, converts that terminal response into a single Chat
// Completions JSON document, and writes it to the client with status 200.
//
// If the stream cannot be read the error is returned as-is. If the stream ends
// without a terminal response, a 502 Chat Completions error is written to the
// client and an error is returned.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	billingModel string,
	upstreamModel string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	upstreamRequestID := resp.Header.Get("x-request-id")

	terminal, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", upstreamRequestID)
	switch {
	case err != nil:
		return nil, err
	case terminal == nil:
		writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
		return nil, fmt.Errorf("upstream stream ended without terminal event")
	}

	// The terminal event may carry an empty output array; fill it in from the
	// accumulated delta events so the client still gets the full content.
	acc.SupplementResponseOutput(terminal)
	converted := apicompat.ResponsesToChatCompletions(terminal, originalModel)

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, converted)

	result := &OpenAIForwardResult{
		RequestID:     upstreamRequestID,
		Usage:         usage,
		Model:         originalModel,
		BillingModel:  billingModel,
		UpstreamModel: upstreamModel,
		Stream:        false,
		Duration:      time.Since(startTime),
	}
	return result, nil
}
// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
//
// The SSE response headers and a 200 status are written before any upstream
// data is read. Reading stops at the first terminal event (as decided by
// isOpenAICompatResponsesTerminalEvent); any trailing chunks and the
// "data: [DONE]" sentinel are then sent. If a write to the client fails, the
// upstream is still consumed (without further writes) so that usage from the
// terminal event can be captured for billing.
//
// Two read modes exist: a synchronous loop when neither
// StreamDataIntervalTimeout nor StreamKeepaliveInterval is configured, and a
// goroutine + select loop otherwise.
//
// Every return path yields a non-nil result built from the usage captured so
// far. The error is nil only when a terminal event was processed; an upstream
// "[DONE]" or EOF before a terminal event, a scanner error, or a data interval
// timeout all return a non-nil error alongside the partial result.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
billingModel string,
upstreamModel string,
includeUsage bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
// SSE headers are set after the filtered upstream headers, so they override
// any same-named header copied from upstream.
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
// Conversion state carried across events; chunks report the client's
// original model name.
state := apicompat.NewResponsesEventToChatState()
state.Model = originalModel
state.IncludeUsage = includeUsage
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
clientDisconnected := false
// Line scanner over the upstream body: 64 KiB initial buffer, growing up to
// maxLineSize per line.
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// Optional upstream inactivity timeout (configured in seconds).
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
// Left nil when the timeout is disabled; a nil channel never fires in select.
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// resultWithUsage snapshots the current usage / first-token timing into a
// result; used by both success and failure exits.
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
// processDataLine handles one SSE data payload: it parses the event, records
// usage from terminal events, converts to Chat Completions chunks and writes
// them. It reports whether the event was terminal; an unparsable payload is
// logged and reported as non-terminal.
processDataLine := func(payload string) bool {
// Time to first token is measured at the first data payload, before it is
// parsed.
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
return false
}
// Extract usage only from the terminal events that the compat converter
// supports, to avoid unintentionally broadening event semantics.
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
}
// Conversion always runs, even after a disconnect; only the writes are
// skipped.
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
if !clientDisconnected {
for _, chunk := range chunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
zap.String("request_id", requestID),
)
break
}
}
}
if len(chunks) > 0 && !clientDisconnected {
c.Writer.Flush()
}
return isTerminalEvent
}
// finalizeStream runs after a terminal event: it emits any trailing chunks
// from the converter plus the [DONE] sentinel (unless the client is gone) and
// returns the successful result.
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
// A chunk that cannot be serialized is skipped without logging.
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
}
}
// Send [DONE] sentinel
if !clientDisconnected {
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
zap.String("request_id", requestID),
)
}
}
if !clientDisconnected {
c.Writer.Flush()
}
return resultWithUsage(), nil
}
// handleScanErr logs an upstream read error, staying silent for context
// cancellation and deadline expiry.
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// missingTerminalErr is the exit used when the upstream finishes ([DONE] or
// EOF) without a terminal event: partial result plus an error.
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// Determine keepalive interval
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// No keepalive: fast synchronous path
if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
// Not a data line: skip it.
continue
}
// An upstream [DONE] before any terminal event means usage was never
// delivered.
if strings.TrimSpace(payload) == "[DONE]" {
return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
}
}
if err := scanner.Err(); err != nil {
handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
}
// With keepalive: goroutine + channel + select
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
// done is closed when this function returns (see the defer below) so the
// reader goroutine stops trying to send.
done := make(chan struct{})
// lastReadAt holds the UnixNano time of the most recent upstream line; it is
// written by the reader goroutine and read by the select loop, hence atomic.
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
// sendEvent delivers ev unless the consumer has already returned; it reports
// whether the event was delivered.
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
// Reader goroutine: forwards each scanned line, then any scanner error, and
// closes events on exit.
// NOTE(review): a Scan call that is blocked on resp.Body is not interrupted
// by closing done; it presumably unblocks once the caller closes the body —
// confirm.
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
// Left nil when keepalive is disabled; a nil channel never fires in select.
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
// lastDataAt is when the select loop last received a line from the reader
// (any line, not only data lines); it gates keepalive emission.
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
// Channel closed: upstream hit EOF without a terminal event.
return missingTerminalErr()
}
if ev.err != nil {
handleScanErr(ev.err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
}
lastDataAt = time.Now()
line := ev.line
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
if strings.TrimSpace(payload) == "[DONE]" {
return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
}
case <-intervalCh:
// The ticker is periodic, so re-check against the real last read time
// before treating this tick as a timeout.
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
// With the client already gone, return without the warning log.
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai chat_completions stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
// Skip the keepalive if a line arrived within the last interval.
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send SSE comment as keepalive
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
clientDisconnected = true
continue
}
c.Writer.Flush()
}
}
}
// writeChatCompletionsError responds with statusCode and a JSON body of the
// form {"error": {"type": errType, "message": message}}, the error envelope
// used by the OpenAI Chat Completions API.
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
	detail := gin.H{
		"type":    errType,
		"message": message,
	}
	c.JSON(statusCode, gin.H{"error": detail})
}