refactor: 重构 Chat Completions 端点,采用类型安全的 Responses API 转换

将 /v1/chat/completions 端点从 ResponseWriter 劫持模式重构为独立的
类型安全转换路径,与 Anthropic Messages 端点架构对齐:

- 在 apicompat 包新增 Chat Completions 完整类型定义和双向转换器
- 新增 ForwardAsChatCompletions service 方法,走 Responses API 上游
- Handler 改为独立的账号选择/failover 循环,不再劫持 Responses handler
- 提取 handleCompatErrorResponse 为 Chat Completions 和 Messages 共用
- 删除旧的 forwardChatCompletions 直传路径及相关死代码
This commit is contained in:
shaw
2026-03-11 22:10:22 +08:00
parent 8dd38f4775
commit 9d81467937
11 changed files with 2420 additions and 1717 deletions

View File

@@ -1,513 +0,0 @@
package service
import (
"encoding/json"
"errors"
"strings"
"time"
)
// ConvertChatCompletionsToResponses converts an OpenAI Chat Completions request to a Responses request.
//
// Legacy Chat Completions fields are translated into their Responses
// equivalents: "messages" become "input" items, max_tokens /
// max_completion_tokens map to max_output_tokens, and the deprecated
// functions / function_call pair is lifted into tools / tool_choice.
// Every other top-level key is copied through verbatim.
func ConvertChatCompletionsToResponses(req map[string]any) (map[string]any, error) {
	if req == nil {
		return nil, errors.New("request is nil")
	}
	model := strings.TrimSpace(getString(req["model"]))
	if model == "" {
		return nil, errors.New("model is required")
	}
	rawMessages, present := req["messages"]
	if !present {
		return nil, errors.New("messages is required")
	}
	messageList, isList := rawMessages.([]any)
	if !isList {
		return nil, errors.New("messages must be an array")
	}
	input, err := convertChatMessagesToResponsesInput(messageList)
	if err != nil {
		return nil, err
	}
	// Keys consumed by the conversion below; everything else passes through.
	dropped := map[string]bool{
		"messages":              true,
		"max_tokens":            true,
		"max_completion_tokens": true,
		"stream_options":        true,
		"functions":             true,
		"function_call":         true,
	}
	out := make(map[string]any, len(req)+1)
	for key, value := range req {
		if !dropped[key] {
			out[key] = value
		}
	}
	out["model"] = model
	out["input"] = input
	if _, has := out["max_output_tokens"]; !has {
		// Prefer the legacy max_tokens, then the newer max_completion_tokens.
		if v, has := req["max_tokens"]; has {
			out["max_output_tokens"] = v
		} else if v, has := req["max_completion_tokens"]; has {
			out["max_output_tokens"] = v
		}
	}
	if _, has := out["tools"]; !has {
		if functions, isSlice := req["functions"].([]any); isSlice && len(functions) > 0 {
			tools := make([]any, 0, len(functions))
			for _, fn := range functions {
				fnMap, isMap := fn.(map[string]any)
				if !isMap {
					continue
				}
				tools = append(tools, map[string]any{
					"type":     "function",
					"function": fnMap,
				})
			}
			if len(tools) > 0 {
				out["tools"] = tools
			}
		}
	}
	if _, has := out["tool_choice"]; !has {
		if fc, has := req["function_call"]; has {
			out["tool_choice"] = fc
		}
	}
	return out, nil
}
// ConvertResponsesToChatCompletion converts an OpenAI Responses response body to Chat Completions format.
//
// It synthesizes a chatcmpl-* id and a created timestamp when the upstream
// omits them, flattens the Responses output into a single assistant message
// (text plus optional tool_calls), and maps usage counters into the
// prompt/completion/total token shape expected by Chat Completions clients.
func ConvertResponsesToChatCompletion(body []byte) ([]byte, error) {
	var resp map[string]any
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	id := strings.TrimSpace(getString(resp["id"]))
	if id == "" {
		// Always present an identifier to the client, even if upstream omitted one.
		id = "chatcmpl-" + safeRandomHex(12)
	}
	created := getInt64(resp["created_at"])
	if created == 0 {
		created = getInt64(resp["created"])
	}
	if created == 0 {
		created = time.Now().Unix()
	}
	text, toolCalls := extractResponseTextAndToolCalls(resp)
	finish := "stop"
	message := map[string]any{
		"role":    "assistant",
		"content": text,
	}
	if len(toolCalls) > 0 {
		finish = "tool_calls"
		message["tool_calls"] = toolCalls
	}
	out := map[string]any{
		"id":      id,
		"object":  "chat.completion",
		"created": created,
		"model":   strings.TrimSpace(getString(resp["model"])),
		"choices": []any{
			map[string]any{
				"index":         0,
				"message":       message,
				"finish_reason": finish,
			},
		},
	}
	if usage := extractResponseUsage(resp); usage != nil {
		out["usage"] = usage
	}
	if fp := strings.TrimSpace(getString(resp["system_fingerprint"])); fp != "" {
		out["system_fingerprint"] = fp
	}
	return json.Marshal(out)
}
// convertChatMessagesToResponsesInput maps Chat Completions messages onto
// Responses "input" items. Tool/function result messages become
// function_call_output items; assistant tool invocations are appended as
// standalone call items after the assistant message (or instead of it when the
// assistant message carries no content of its own).
func convertChatMessagesToResponsesInput(messages []any) ([]any, error) {
	input := make([]any, 0, len(messages))
	appendToolOutput := func(callID, output string) {
		input = append(input, map[string]any{
			"type":    "function_call_output",
			"call_id": callID,
			"output":  output,
		})
	}
	for _, raw := range messages {
		msg, isMap := raw.(map[string]any)
		if !isMap {
			return nil, errors.New("message must be an object")
		}
		role := strings.TrimSpace(getString(msg["role"]))
		if role == "" {
			return nil, errors.New("message role is required")
		}
		switch role {
		case "tool":
			callID := strings.TrimSpace(getString(msg["tool_call_id"]))
			if callID == "" {
				callID = strings.TrimSpace(getString(msg["id"]))
			}
			appendToolOutput(callID, extractMessageContentText(msg["content"]))
		case "function":
			// Legacy function role has no call id; the function name stands in.
			appendToolOutput(strings.TrimSpace(getString(msg["name"])), extractMessageContentText(msg["content"]))
		default:
			content := convertChatContent(msg["content"])
			var toolCalls []any
			if role == "assistant" {
				toolCalls = extractToolCallsFromMessage(msg)
			}
			// An assistant turn that only carries tool calls is represented by
			// the call items alone; emitting an empty message would be noise.
			if !(role == "assistant" && len(toolCalls) > 0 && isEmptyContent(content)) {
				item := map[string]any{
					"role":    role,
					"content": content,
				}
				if name := strings.TrimSpace(getString(msg["name"])); name != "" {
					item["name"] = name
				}
				input = append(input, item)
			}
			// toolCalls is non-nil only for assistant messages; appending nil is a no-op.
			input = append(input, toolCalls...)
		}
	}
	return input, nil
}
// convertChatContent rewrites a Chat Completions "content" value into the
// Responses content-part vocabulary (input_text / input_image). Scalar and
// unknown values pass through unchanged; nil becomes the empty string.
func convertChatContent(content any) any {
	parts, isList := content.([]any)
	if !isList {
		if content == nil {
			return ""
		}
		return content
	}
	converted := make([]any, 0, len(parts))
	for _, part := range parts {
		partMap, isMap := part.(map[string]any)
		if !isMap {
			converted = append(converted, part)
			continue
		}
		switch strings.TrimSpace(getString(partMap["type"])) {
		case "text":
			if text := getString(partMap["text"]); text != "" {
				converted = append(converted, map[string]any{
					"type": "input_text",
					"text": text,
				})
				continue
			}
		case "image_url":
			// Accept both {"image_url": {"url": ...}} and a bare string URL.
			url := ""
			if obj, isObj := partMap["image_url"].(map[string]any); isObj {
				url = getString(obj["url"])
			} else {
				url = getString(partMap["image_url"])
			}
			if url != "" {
				converted = append(converted, map[string]any{
					"type":      "input_image",
					"image_url": url,
				})
				continue
			}
		case "input_text", "input_image":
			// Already in Responses form.
			converted = append(converted, partMap)
			continue
		}
		// Empty or unrecognized parts fall through verbatim.
		converted = append(converted, partMap)
	}
	return converted
}
// extractToolCallsFromMessage collects the assistant's tool invocations from a
// Chat Completions message, covering both the modern "tool_calls" array
// (emitted as type "tool_call") and the legacy "function_call" object
// (emitted as type "function_call"). Entries with neither name nor arguments
// are skipped.
func extractToolCallsFromMessage(msg map[string]any) []any {
	var out []any
	emit := func(itemType, callID, name, args string) {
		item := map[string]any{"type": itemType}
		if callID != "" {
			item["call_id"] = callID
		}
		if name != "" {
			item["name"] = name
		}
		if args != "" {
			item["arguments"] = args
		}
		out = append(out, item)
	}
	if calls, isList := msg["tool_calls"].([]any); isList {
		for _, raw := range calls {
			call, isMap := raw.(map[string]any)
			if !isMap {
				continue
			}
			callID := strings.TrimSpace(getString(call["id"]))
			if callID == "" {
				callID = strings.TrimSpace(getString(call["call_id"]))
			}
			var name, args string
			if fn, isFn := call["function"].(map[string]any); isFn {
				name = strings.TrimSpace(getString(fn["name"]))
				args = getString(fn["arguments"])
			}
			if name == "" && args == "" {
				continue
			}
			emit("tool_call", callID, name, args)
		}
	}
	if fnCall, isMap := msg["function_call"].(map[string]any); isMap {
		name := strings.TrimSpace(getString(fnCall["name"]))
		args := getString(fnCall["arguments"])
		if name != "" || args != "" {
			callID := strings.TrimSpace(getString(msg["tool_call_id"]))
			if callID == "" {
				// Legacy function_call carries no id; reuse the name.
				callID = name
			}
			emit("function_call", callID, name, args)
		}
	}
	return out
}
// extractMessageContentText flattens a message "content" value into plain
// text, concatenating every textual part and ignoring non-text parts. Any
// other value type yields the empty string.
func extractMessageContentText(content any) string {
	switch v := content.(type) {
	case string:
		return v
	case []any:
		var b strings.Builder
		for _, raw := range v {
			part, isMap := raw.(map[string]any)
			if !isMap {
				continue
			}
			switch strings.TrimSpace(getString(part["type"])) {
			case "", "text", "output_text", "input_text":
				b.WriteString(getString(part["text"]))
			}
		}
		return b.String()
	default:
		return ""
	}
}
// isEmptyContent reports whether a converted content value carries no data:
// nil, a blank (all-whitespace) string, or an empty part list. Every other
// value counts as content.
func isEmptyContent(content any) bool {
	if content == nil {
		return true
	}
	switch v := content.(type) {
	case string:
		return strings.TrimSpace(v) == ""
	case []any:
		return len(v) == 0
	}
	return false
}
func extractResponseTextAndToolCalls(resp map[string]any) (string, []any) {
output, ok := resp["output"].([]any)
if !ok {
if text, ok := resp["output_text"].(string); ok {
return text, nil
}
return "", nil
}
textParts := make([]string, 0)
toolCalls := make([]any, 0)
for _, item := range output {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType := strings.TrimSpace(getString(itemMap["type"]))
if itemType == "tool_call" || itemType == "function_call" {
if tc := responseItemToChatToolCall(itemMap); tc != nil {
toolCalls = append(toolCalls, tc)
}
continue
}
content := itemMap["content"]
switch v := content.(type) {
case string:
if v != "" {
textParts = append(textParts, v)
}
case []any:
for _, part := range v {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
partType := strings.TrimSpace(getString(partMap["type"]))
switch partType {
case "output_text", "text", "input_text":
text := getString(partMap["text"])
if text != "" {
textParts = append(textParts, text)
}
case "tool_call", "function_call":
if tc := responseItemToChatToolCall(partMap); tc != nil {
toolCalls = append(toolCalls, tc)
}
}
}
}
}
return strings.Join(textParts, ""), toolCalls
}
// responseItemToChatToolCall converts a Responses tool/function call item into
// the Chat Completions tool_calls element shape. It returns nil when the item
// carries no identifying information; if only the id is missing, a random
// call_* id is synthesized.
func responseItemToChatToolCall(item map[string]any) map[string]any {
	callID := strings.TrimSpace(getString(item["call_id"]))
	if callID == "" {
		callID = strings.TrimSpace(getString(item["id"]))
	}
	name := strings.TrimSpace(getString(item["name"]))
	arguments := getString(item["arguments"])
	// A nested {"function": {...}} object supplies fallbacks for name/arguments.
	if fn, isMap := item["function"].(map[string]any); isMap {
		if name == "" {
			name = strings.TrimSpace(getString(fn["name"]))
		}
		if arguments == "" {
			arguments = getString(fn["arguments"])
		}
	}
	if callID == "" && name == "" && arguments == "" {
		return nil
	}
	if callID == "" {
		callID = "call_" + safeRandomHex(6)
	}
	return map[string]any{
		"id":   callID,
		"type": "function",
		"function": map[string]any{
			"name":      name,
			"arguments": arguments,
		},
	}
}
// extractResponseUsage maps Responses usage counters onto Chat Completions
// usage fields. It returns nil when usage is absent or entirely zero, so
// callers can omit the field rather than report empty usage.
func extractResponseUsage(resp map[string]any) map[string]any {
	usageMap, isMap := resp["usage"].(map[string]any)
	if !isMap {
		return nil
	}
	prompt := int(getNumber(usageMap["input_tokens"]))
	completion := int(getNumber(usageMap["output_tokens"]))
	if prompt == 0 && completion == 0 {
		return nil
	}
	return map[string]any{
		"prompt_tokens":     prompt,
		"completion_tokens": completion,
		"total_tokens":      prompt + completion,
	}
}
func getString(value any) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case json.Number:
return v.String()
default:
return ""
}
}
func getNumber(value any) float64 {
switch v := value.(type) {
case float64:
return v
case float32:
return float64(v)
case int:
return float64(v)
case int64:
return float64(v)
case json.Number:
f, _ := v.Float64()
return f
default:
return 0
}
}
func getInt64(value any) int64 {
switch v := value.(type) {
case int64:
return v
case int:
return int64(v)
case float64:
return int64(v)
case json.Number:
i, _ := v.Int64()
return i
default:
return 0
}
}
// safeRandomHex returns a random hex string for the given number of bytes,
// falling back to the constant placeholder "000000" if randomness fails.
// Note the fallback length is fixed and does not track byteLength.
func safeRandomHex(byteLength int) string {
	if value, err := randomHexString(byteLength); err == nil && value != "" {
		return value
	}
	return "000000"
}

View File

@@ -1,488 +0,0 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
)
// chatStreamingResult carries the aggregate outcome of relaying a Chat
// Completions SSE stream to the client.
type chatStreamingResult struct {
	usage        *OpenAIUsage // token usage accumulated from stream events
	firstTokenMs *int         // ms from request start to first content delta; nil if none observed
}
// forwardChatCompletions proxies a Chat Completions request directly to the
// upstream /chat/completions endpoint for one account. It applies the
// account's model mapping, opts streaming requests into usage reporting,
// forwards the request, and relays the streaming or buffered response.
// A returned *UpstreamFailoverError tells the caller to retry on another
// account; any other error is terminal for this request.
func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) {
	// Parse request body once (avoid multiple parse/serialize cycles)
	var reqBody map[string]any
	if err := json.Unmarshal(body, &reqBody); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}
	reqModel, _ := reqBody["model"].(string)
	reqStream, _ := reqBody["stream"].(bool)
	originalModel := reqModel
	bodyModified := false
	// Swap in the account's upstream model name when a mapping is configured.
	mappedModel := account.GetMappedModel(reqModel)
	if mappedModel != reqModel {
		log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
		reqBody["model"] = mappedModel
		bodyModified = true
	}
	// For streaming requests, ask upstream for usage chunks unless the client
	// already expressed an explicit include_usage preference.
	if reqStream && includeUsage {
		streamOptions, _ := reqBody["stream_options"].(map[string]any)
		if streamOptions == nil {
			streamOptions = map[string]any{}
		}
		if _, ok := streamOptions["include_usage"]; !ok {
			streamOptions["include_usage"] = true
			reqBody["stream_options"] = streamOptions
			bodyModified = true
		}
	}
	// Re-serialize only if something above actually changed the body.
	if bodyModified {
		var err error
		body, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("serialize request body: %w", err)
		}
	}
	// Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}
	upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token)
	if err != nil {
		return nil, err
	}
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	// NOTE(review): c is nil-checked here but dereferenced unconditionally
	// later (c.JSON below) — confirm callers never pass a nil gin.Context.
	if c != nil {
		c.Set(OpsUpstreamRequestBodyKey, string(body))
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// Transport-level failure: record ops telemetry and answer 502.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream request failed",
			},
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode >= 400 {
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			// Buffer (up to 2 MiB of) the error body so it can be logged and
			// still re-read by downstream handlers.
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(respBody))
			upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
			upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
			upstreamDetail := ""
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			s.handleFailoverSideEffects(ctx, resp, account)
			return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
		}
		// Non-failover upstream errors are relayed to the client as-is.
		return s.handleErrorResponse(ctx, resp, c, account, body)
	}
	var usage *OpenAIUsage
	var firstTokenMs *int
	if reqStream {
		streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
		usage = streamResult.usage
		firstTokenMs = streamResult.firstTokenMs
	} else {
		usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
	}
	if usage == nil {
		usage = &OpenAIUsage{}
	}
	return &OpenAIForwardResult{
		RequestID:    resp.Header.Get("x-request-id"),
		Usage:        *usage,
		Model:        originalModel,
		Stream:       reqStream,
		Duration:     time.Since(startTime),
		FirstTokenMs: firstTokenMs,
	}, nil
}
// buildChatCompletionsRequest assembles the upstream POST to the account's
// /chat/completions endpoint: bearer token, an allow-listed subset of the
// client's headers, an optional custom user agent, and a JSON content type
// default when the client supplied none.
func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) {
	targetURL := openaiChatAPIURL
	if baseURL := account.GetOpenAIBaseURL(); baseURL != "" {
		validated, err := s.validateUpstreamBaseURL(baseURL)
		if err != nil {
			return nil, err
		}
		targetURL = validated + "/chat/completions"
	}
	req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("authorization", "Bearer "+token)
	// Forward only allow-listed client headers upstream.
	for key, values := range c.Request.Header {
		if !openaiChatAllowedHeaders[strings.ToLower(key)] {
			continue
		}
		for _, v := range values {
			req.Header.Add(key, v)
		}
	}
	if ua := account.GetOpenAIUserAgent(); ua != "" {
		req.Header.Set("user-agent", ua)
	}
	if req.Header.Get("content-type") == "" {
		req.Header.Set("content-type", "application/json")
	}
	return req, nil
}
// handleChatCompletionsStreamingResponse relays an upstream Chat Completions
// SSE stream to the client line by line, restoring the client-requested model
// name and applying tool-call correction, while tracking usage, time to first
// token, data-interval timeouts, and client keepalives.
//
// Fix: the SSE payload handed to CorrectToolCallsInSSEData was previously
// captured BEFORE replaceModelInSSELine rewrote the line, so whenever both
// rewrites applied to the same chunk the corrected output silently reverted
// the model rename. The data payload is now re-extracted after the model
// replacement; the duplicated first-token/usage parse is also folded into a
// single parse per chunk.
func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")
	if v := resp.Header.Get("x-request-id"); v != "" {
		c.Header("x-request-id", v)
	}
	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("streaming not supported")
	}
	usage := &OpenAIUsage{}
	var firstTokenMs *int
	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 64*1024), maxLineSize)
	type scanEvent struct {
		line string
		err  error
	}
	// A reader goroutine feeds lines through a channel so the select loop
	// below can multiplex stream data with timeout/keepalive tickers.
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	var lastReadAt int64
	atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
	go func() {
		defer close(events)
		for scanner.Scan() {
			atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)
	// Interval ticker: abort if no upstream data arrives within the window.
	streamInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	var intervalTicker *time.Ticker
	if streamInterval > 0 {
		intervalTicker = time.NewTicker(streamInterval)
		defer intervalTicker.Stop()
	}
	var intervalCh <-chan time.Time
	if intervalTicker != nil {
		intervalCh = intervalTicker.C
	}
	// Keepalive ticker: emit SSE comments so idle client connections stay open.
	keepaliveInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
	}
	var keepaliveTicker *time.Ticker
	if keepaliveInterval > 0 {
		keepaliveTicker = time.NewTicker(keepaliveInterval)
		defer keepaliveTicker.Stop()
	}
	var keepaliveCh <-chan time.Time
	if keepaliveTicker != nil {
		keepaliveCh = keepaliveTicker.C
	}
	lastDataAt := time.Now()
	errorEventSent := false
	sendErrorEvent := func(reason string) {
		if errorEventSent {
			return
		}
		errorEventSent = true
		_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
		flusher.Flush()
	}
	needModelReplace := originalModel != mappedModel
	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream stream ended cleanly.
				return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
			}
			if ev.err != nil {
				if errors.Is(ev.err, bufio.ErrTooLong) {
					log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
					sendErrorEvent("response_too_large")
					return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
				}
				sendErrorEvent("stream_read_error")
				return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
			}
			line := ev.line
			lastDataAt = time.Now()
			if openaiSSEDataRe.MatchString(line) {
				// Restore the client-requested model name first, then derive
				// the JSON payload from the (possibly rewritten) line so the
				// tool-call corrector cannot undo the rename.
				if needModelReplace {
					line = s.replaceModelInSSELine(line, mappedModel, originalModel)
				}
				data := openaiSSEDataRe.ReplaceAllString(line, "")
				if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
					line = "data: " + correctedData
				}
				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
					sendErrorEvent("write_failed")
					return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
				}
				flusher.Flush()
				// Parse each data chunk once: record first-token latency and
				// accumulate usage counters.
				if event := parseChatStreamEvent(data); event != nil {
					if firstTokenMs == nil && chatChunkHasDelta(event) {
						ms := int(time.Since(startTime).Milliseconds())
						firstTokenMs = &ms
					}
					applyChatUsageFromEvent(event, usage)
				}
			} else {
				// Non-data lines (event names, comments, blanks) pass through.
				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
					sendErrorEvent("write_failed")
					return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
				}
				flusher.Flush()
			}
		case <-intervalCh:
			lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
			if time.Since(lastRead) < streamInterval {
				continue
			}
			log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
			if s.rateLimitService != nil {
				s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
			}
			sendErrorEvent("stream_timeout")
			return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
		case <-keepaliveCh:
			if time.Since(lastDataAt) < keepaliveInterval {
				continue
			}
			// SSE comment keepalive.
			if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
				return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
			}
			flusher.Flush()
		}
	}
}
// handleChatCompletionsNonStreamingResponse relays a buffered upstream
// response to the client. It extracts usage from the raw payload first, then
// restores the client-requested model name, applies tool-call correction, and
// writes the (optionally header-filtered) body through.
func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) {
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	usage := &OpenAIUsage{}
	// Pull usage out of the original payload before any rewriting.
	var decoded map[string]any
	if json.Unmarshal(payload, &decoded) == nil {
		if usageMap, isMap := decoded["usage"].(map[string]any); isMap {
			applyChatUsageFromMap(usageMap, usage)
		}
	}
	if originalModel != mappedModel {
		payload = s.replaceModelInResponseBody(payload, mappedModel, originalModel)
	}
	payload = s.correctToolCallsInResponseBody(payload)
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	// Mirror the upstream content type only when header filtering is disabled.
	contentType := "application/json"
	if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
		if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
			contentType = upstreamType
		}
	}
	c.Data(resp.StatusCode, contentType, payload)
	return usage, nil
}
// parseChatStreamEvent decodes one SSE data payload into a generic JSON map.
// It returns nil for empty payloads, the [DONE] sentinel, and malformed JSON.
func parseChatStreamEvent(data string) map[string]any {
	switch data {
	case "", "[DONE]":
		return nil
	}
	var event map[string]any
	if err := json.Unmarshal([]byte(data), &event); err != nil {
		return nil
	}
	return event
}
// chatChunkHasDelta reports whether a streamed chunk carries real assistant
// output: a non-blank content delta, a tool_calls delta, or a legacy
// function_call delta. Whitespace-only content does not count.
func chatChunkHasDelta(event map[string]any) bool {
	choices, isList := event["choices"].([]any)
	if !isList {
		return false
	}
	for _, rawChoice := range choices {
		choice, isMap := rawChoice.(map[string]any)
		if !isMap {
			continue
		}
		delta, hasDelta := choice["delta"].(map[string]any)
		if !hasDelta {
			continue
		}
		if content, isString := delta["content"].(string); isString && strings.TrimSpace(content) != "" {
			return true
		}
		if toolCalls, isCalls := delta["tool_calls"].([]any); isCalls && len(toolCalls) > 0 {
			return true
		}
		if fnCall, isFn := delta["function_call"].(map[string]any); isFn && len(fnCall) > 0 {
			return true
		}
	}
	return false
}
// applyChatUsageFromEvent folds a stream event's usage payload (if present)
// into the running usage totals; nil arguments are ignored.
func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) {
	if event == nil || usage == nil {
		return
	}
	if usageMap, isMap := event["usage"].(map[string]any); isMap {
		applyChatUsageFromMap(usageMap, usage)
	}
}
// applyChatUsageFromMap copies positive prompt/completion token counts from a
// decoded usage object into the accumulator. Zero counts are ignored so that
// previously recorded values are never clobbered by partial usage chunks.
func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) {
	if usageMap == nil || usage == nil {
		return
	}
	if prompt := int(getNumber(usageMap["prompt_tokens"])); prompt > 0 {
		usage.InputTokens = prompt
	}
	if completion := int(getNumber(usageMap["completion_tokens"])); completion > 0 {
		usage.OutputTokens = completion
	}
}

View File

@@ -1,132 +0,0 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
// TestConvertChatCompletionsToResponses covers the messages→input
// translation, legacy functions→tools lifting, and the
// function_call→tool_choice mapping.
func TestConvertChatCompletionsToResponses(t *testing.T) {
	chatReq := map[string]any{
		"model": "gpt-4o",
		"messages": []any{
			map[string]any{"role": "user", "content": "hello"},
			map[string]any{
				"role": "assistant",
				"tool_calls": []any{
					map[string]any{
						"id":   "call_1",
						"type": "function",
						"function": map[string]any{
							"name":      "ping",
							"arguments": "{}",
						},
					},
				},
			},
			map[string]any{
				"role":          "tool",
				"tool_call_id":  "call_1",
				"content":       "ok",
				"response":      "ignored",
				"response_time": 1,
			},
		},
		"functions": []any{
			map[string]any{
				"name":        "ping",
				"description": "ping tool",
				"parameters":  map[string]any{"type": "object"},
			},
		},
		"function_call": map[string]any{"name": "ping"},
	}
	out, err := ConvertChatCompletionsToResponses(chatReq)
	require.NoError(t, err)
	require.Equal(t, "gpt-4o", out["model"])
	input, ok := out["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 3)
	call := findInputItemByType(input, "tool_call")
	require.NotNil(t, call)
	require.Equal(t, "call_1", call["call_id"])
	callResult := findInputItemByType(input, "function_call_output")
	require.NotNil(t, callResult)
	require.Equal(t, "call_1", callResult["call_id"])
	tools, ok := out["tools"].([]any)
	require.True(t, ok)
	require.Len(t, tools, 1)
	require.Equal(t, map[string]any{"name": "ping"}, out["tool_choice"])
}
// TestConvertResponsesToChatCompletion checks the id/object/choice shaping and
// the usage counter mapping for a minimal Responses payload.
func TestConvertResponsesToChatCompletion(t *testing.T) {
	upstream := map[string]any{
		"id":         "resp_123",
		"model":      "gpt-4o",
		"created_at": 1700000000,
		"output": []any{
			map[string]any{
				"type": "message",
				"role": "assistant",
				"content": []any{
					map[string]any{"type": "output_text", "text": "hi"},
				},
			},
		},
		"usage": map[string]any{
			"input_tokens":  2,
			"output_tokens": 3,
		},
	}
	raw, err := json.Marshal(upstream)
	require.NoError(t, err)
	out, err := ConvertResponsesToChatCompletion(raw)
	require.NoError(t, err)
	var chat map[string]any
	require.NoError(t, json.Unmarshal(out, &chat))
	require.Equal(t, "chat.completion", chat["object"])
	choices, ok := chat["choices"].([]any)
	require.True(t, ok)
	require.Len(t, choices, 1)
	firstChoice, ok := choices[0].(map[string]any)
	require.True(t, ok)
	message, ok := firstChoice["message"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "hi", message["content"])
	usage, ok := chat["usage"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, float64(2), usage["prompt_tokens"])
	require.Equal(t, float64(3), usage["completion_tokens"])
	require.Equal(t, float64(5), usage["total_tokens"])
}
// findInputItemByType returns the first input item whose "type" field equals
// itemType, or nil when no such item exists.
func findInputItemByType(items []any, itemType string) map[string]any {
	for _, raw := range items {
		if m, ok := raw.(map[string]any); ok && m["type"] == itemType {
			return m
		}
	}
	return nil
}

View File

@@ -0,0 +1,512 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Chat Completions format. All account types (OAuth and API
// Key) go through the Responses API conversion path since the upstream only
// exposes the /v1/responses endpoint.
//
// A returned *UpstreamFailoverError signals the caller's account-selection
// loop to retry on another account; other errors are terminal for this request.
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	promptCacheKey string,
	defaultMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()
	// 1. Parse Chat Completions request
	var chatReq apicompat.ChatCompletionsRequest
	if err := json.Unmarshal(body, &chatReq); err != nil {
		return nil, fmt.Errorf("parse chat completions request: %w", err)
	}
	originalModel := chatReq.Model
	clientStream := chatReq.Stream
	includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
	// 2. Convert to Responses and forward
	// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
	responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
	if err != nil {
		return nil, fmt.Errorf("convert chat completions to responses: %w", err)
	}
	// 3. Model mapping: the account's own mapping wins; otherwise fall back to
	// the caller-supplied default (if any).
	mappedModel := account.GetMappedModel(originalModel)
	if mappedModel == originalModel && defaultMappedModel != "" {
		mappedModel = defaultMappedModel
	}
	responsesReq.Model = mappedModel
	logger.L().Debug("openai chat_completions: model mapping applied",
		zap.Int64("account_id", account.ID),
		zap.String("original_model", originalModel),
		zap.String("mapped_model", mappedModel),
		zap.Bool("stream", clientStream),
	)
	// 4. Marshal Responses request body, then apply OAuth codex transform
	responsesBody, err := json.Marshal(responsesReq)
	if err != nil {
		return nil, fmt.Errorf("marshal responses request: %w", err)
	}
	if account.Type == AccountTypeOAuth {
		var reqBody map[string]any
		if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
			return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
		}
		// NOTE(review): presumably applyCodexOAuthTransform writes its own
		// prompt_cache_key into reqBody when it returns one — confirm. Here the
		// caller's key is injected only when the transform produced none.
		codexResult := applyCodexOAuthTransform(reqBody, false, false)
		if codexResult.PromptCacheKey != "" {
			promptCacheKey = codexResult.PromptCacheKey
		} else if promptCacheKey != "" {
			reqBody["prompt_cache_key"] = promptCacheKey
		}
		responsesBody, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("remarshal after codex transform: %w", err)
		}
	}
	// 5. Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}
	// 6. Build upstream request
	upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
	if err != nil {
		return nil, fmt.Errorf("build upstream request: %w", err)
	}
	if promptCacheKey != "" {
		// Derive a stable session id from the cache key for upstream affinity.
		upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
	}
	// 7. Send request
	proxyURL := ""
	if account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// Transport-level failure: record ops telemetry and answer 502 in
		// Chat Completions error format.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()
	// 8. Handle error response with failover
	if resp.StatusCode >= 400 {
		// Buffer (up to 2 MiB of) the error body so it can be both inspected
		// and re-read downstream.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			upstreamDetail := ""
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			if s.rateLimitService != nil {
				s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
			}
			// Pool-mode accounts may retry on the same account for transient
			// statuses; otherwise the caller moves to the next account.
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
			}
		}
		return s.handleChatCompletionsErrorResponse(resp, c, account)
	}
	// 9. Handle normal response
	var result *OpenAIForwardResult
	var handleErr error
	if clientStream {
		result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
	} else {
		// Upstream always streams; buffer the SSE events into one JSON reply.
		result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
	}
	// Propagate ServiceTier and ReasoningEffort to result for billing
	if handleErr == nil && result != nil {
		if responsesReq.ServiceTier != "" {
			st := responsesReq.ServiceTier
			result.ServiceTier = &st
		}
		if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
			re := responsesReq.Reasoning.Effort
			result.ReasoningEffort = &re
		}
	}
	// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
	if handleErr == nil && account.Type == AccountTypeOAuth {
		if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
			s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
		}
	}
	return result, handleErr
}
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
//
// Thin adapter over handleCompatErrorResponse, which is shared with the
// Anthropic Messages compat path; only the final format-specific error
// writer differs between the two.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
	resp *http.Response,
	c *gin.Context,
	account *Account,
) (*OpenAIForwardResult, error) {
	return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}
// handleChatBufferedStreamingResponse consumes the entire upstream Responses
// SSE stream, keeps the last terminal event (response.completed /
// response.incomplete / response.failed), converts it into a single
// non-streaming Chat Completions JSON body, and writes that to the client.
//
// Usage is taken from the terminal event's usage payload (including cached
// input tokens when present) and returned for billing.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	// Scanner line limit is configurable; SSE payloads can be large.
	lineLimit := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		lineLimit = s.cfg.Gateway.MaxLineSize
	}
	sc := bufio.NewScanner(resp.Body)
	sc.Buffer(make([]byte, 0, 64*1024), lineLimit)

	isTerminal := func(eventType string) bool {
		return eventType == "response.completed" ||
			eventType == "response.incomplete" ||
			eventType == "response.failed"
	}

	var (
		terminal   *apicompat.ResponsesResponse
		tokenUsage OpenAIUsage
	)
	for sc.Scan() {
		line := sc.Text()
		if line == "data: [DONE]" || !strings.HasPrefix(line, "data: ") {
			continue
		}
		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &event); err != nil {
			logger.L().Warn("openai chat_completions buffered: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			continue
		}
		if event.Response == nil || !isTerminal(event.Type) {
			continue
		}
		terminal = event.Response
		if u := event.Response.Usage; u != nil {
			tokenUsage = OpenAIUsage{
				InputTokens:  u.InputTokens,
				OutputTokens: u.OutputTokens,
			}
			if u.InputTokensDetails != nil {
				tokenUsage.CacheReadInputTokens = u.InputTokensDetails.CachedTokens
			}
		}
	}
	// Context cancellations are normal client-side terminations; only log
	// genuine read failures.
	if err := sc.Err(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
		logger.L().Warn("openai chat_completions buffered: read error",
			zap.Error(err),
			zap.String("request_id", requestID),
		)
	}
	if terminal == nil {
		writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
		return nil, fmt.Errorf("upstream stream ended without terminal event")
	}

	chatResp := apicompat.ResponsesToChatCompletions(terminal, originalModel)
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, chatResp)

	return &OpenAIForwardResult{
		RequestID:    requestID,
		Usage:        tokenUsage,
		Model:        originalModel,
		BillingModel: mappedModel,
		Stream:       false,
		Duration:     time.Since(startTime),
	}, nil
}
// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
//
// Flow:
//  1. Mirror filtered upstream headers, then commit SSE headers and a 200
//     status before any body bytes are written.
//  2. For every upstream "data: " line, parse the Responses event, capture
//     usage from terminal events, translate via apicompat, and flush the
//     resulting Chat Completions chunks.
//  3. On stream end, emit any finalizer chunks from apicompat plus the
//     "data: [DONE]" sentinel.
//
// When gateway.stream_keepalive_interval is configured, scanning moves to a
// goroutine so SSE comment keepalives (":") can be written while the upstream
// is silent. A client disconnect at any point ends the stream gracefully and
// still returns the usage accumulated so far for billing.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string, // model name the client requested (echoed in chunks)
	mappedModel string, // upstream model used for billing
	includeUsage bool, // client requested stream_options.include_usage
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	// Standard SSE headers; X-Accel-Buffering disables reverse-proxy (nginx)
	// buffering so chunks reach the client immediately.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)
	// Conversion state carried across events by the apicompat translator.
	state := apicompat.NewResponsesEventToChatState()
	state.Model = originalModel
	state.IncludeUsage = includeUsage
	var usage OpenAIUsage
	var firstTokenMs *int // time-to-first-event, set on the first upstream data line
	firstChunk := true
	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
	// resultWithUsage snapshots the forwarding result for billing/ops logging.
	resultWithUsage := func() *OpenAIForwardResult {
		return &OpenAIForwardResult{
			RequestID:    requestID,
			Usage:        usage,
			Model:        originalModel,
			BillingModel: mappedModel,
			Stream:       true,
			Duration:     time.Since(startTime),
			FirstTokenMs: firstTokenMs,
		}
	}
	// processDataLine handles one SSE payload. It returns true when the client
	// has disconnected and the caller should stop streaming.
	processDataLine := func(payload string) bool {
		if firstChunk {
			firstChunk = false
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}
		var event apicompat.ResponsesStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			// Malformed event: log and keep the stream alive.
			logger.L().Warn("openai chat_completions stream: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
			return false
		}
		// Extract usage from completion events
		if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
			event.Response != nil && event.Response.Usage != nil {
			usage = OpenAIUsage{
				InputTokens:  event.Response.Usage.InputTokens,
				OutputTokens: event.Response.Usage.OutputTokens,
			}
			if event.Response.Usage.InputTokensDetails != nil {
				usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
			}
		}
		// One Responses event may translate to zero or more chat chunks.
		chunks := apicompat.ResponsesEventToChatChunks(&event, state)
		for _, chunk := range chunks {
			sse, err := apicompat.ChatChunkToSSE(chunk)
			if err != nil {
				logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
					zap.Error(err),
					zap.String("request_id", requestID),
				)
				continue
			}
			if _, err := fmt.Fprint(c.Writer, sse); err != nil {
				// A write error means the client went away.
				logger.L().Info("openai chat_completions stream: client disconnected",
					zap.String("request_id", requestID),
				)
				return true
			}
		}
		if len(chunks) > 0 {
			c.Writer.Flush()
		}
		return false
	}
	// finalizeStream emits any trailing chunks produced by the translator,
	// then the [DONE] sentinel, and returns the accumulated result.
	finalizeStream := func() (*OpenAIForwardResult, error) {
		if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
			for _, chunk := range finalChunks {
				sse, err := apicompat.ChatChunkToSSE(chunk)
				if err != nil {
					continue
				}
				fmt.Fprint(c.Writer, sse) //nolint:errcheck
			}
		}
		// Send [DONE] sentinel
		fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
		c.Writer.Flush()
		return resultWithUsage(), nil
	}
	// handleScanErr logs upstream read errors, ignoring normal cancellations.
	handleScanErr := func(err error) {
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("openai chat_completions stream: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}
	// Determine keepalive interval
	keepaliveInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
	}
	// No keepalive: fast synchronous path
	if keepaliveInterval <= 0 {
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
				continue
			}
			if processDataLine(line[6:]) {
				return resultWithUsage(), nil
			}
		}
		handleScanErr(scanner.Err())
		return finalizeStream()
	}
	// With keepalive: goroutine + channel + select
	type scanEvent struct {
		line string
		err  error
	}
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	// sendEvent delivers to the consumer unless it has already exited (done
	// closed); this keeps the reader goroutine from blocking forever.
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	go func() {
		defer close(events)
		for scanner.Scan() {
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}()
	defer close(done)
	keepaliveTicker := time.NewTicker(keepaliveInterval)
	defer keepaliveTicker.Stop()
	lastDataAt := time.Now()
	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Channel closed: upstream EOF, finish the stream normally.
				return finalizeStream()
			}
			if ev.err != nil {
				handleScanErr(ev.err)
				return finalizeStream()
			}
			lastDataAt = time.Now()
			line := ev.line
			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
				continue
			}
			if processDataLine(line[6:]) {
				return resultWithUsage(), nil
			}
		case <-keepaliveTicker.C:
			// Only ping when upstream has been silent for a full interval.
			if time.Since(lastDataAt) < keepaliveInterval {
				continue
			}
			// Send SSE comment as keepalive
			if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
				logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
					zap.String("request_id", requestID),
				)
				return resultWithUsage(), nil
			}
			c.Writer.Flush()
		}
	}
}
// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
	payload := gin.H{
		"error": gin.H{
			"type":    errType,
			"message": message,
		},
	}
	c.JSON(statusCode, payload)
}

View File

@@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
// Non-failover error: return Anthropic-formatted error to client
@@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Record upstream error details for ops logging
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeAnthropicError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from

View File

@@ -12,7 +12,6 @@ import (
"io"
"math/rand"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
@@ -37,7 +36,6 @@ const (
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiChatAPIURL = "https://api.openai.com/v1/chat/completions"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
codexCLIUserAgent = "codex_cli_rs/0.104.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
@@ -56,16 +54,6 @@ const (
codexCLIVersion = "0.104.0"
)
// OpenAIChatCompletionsBodyKey stores the original chat-completions payload in gin.Context.
const OpenAIChatCompletionsBodyKey = "openai_chat_completions_body"
// OpenAIChatCompletionsIncludeUsageKey stores stream_options.include_usage in gin.Context.
const OpenAIChatCompletionsIncludeUsageKey = "openai_chat_completions_include_usage"
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAI allowed headers whitelist (for non-passthrough).
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
@@ -109,19 +97,6 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{
"X-Real-IP",
}
// OpenAI chat-completions allowed headers (extend responses whitelist).
var openaiChatAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
"conversation_id": true,
"user-agent": true,
"originator": true,
"session_id": true,
"openai-organization": true,
"openai-project": true,
"openai-beta": true,
}
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type OpenAICodexUsageSnapshot struct {
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
@@ -1602,23 +1577,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
}
if c != nil && account != nil && account.Type == AccountTypeAPIKey {
if raw, ok := c.Get(OpenAIChatCompletionsBodyKey); ok {
if rawBody, ok := raw.([]byte); ok && len(rawBody) > 0 {
includeUsage := false
if v, ok := c.Get(OpenAIChatCompletionsIncludeUsageKey); ok {
if flag, ok := v.(bool); ok {
includeUsage = flag
}
}
if passthroughWriter, ok := c.Writer.(interface{ SetPassthrough() }); ok {
passthroughWriter.SetPassthrough()
}
return s.forwardChatCompletions(ctx, c, account, rawBody, includeUsage, startTime)
}
}
}
originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel
@@ -2989,6 +2947,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// compatErrorWriter is the signature for format-specific error writers used by
// the compat paths (Chat Completions and Anthropic Messages).
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
// handleCompatErrorResponse is the shared non-failover error handler for the
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
// tracking, secondary failover) but delegates the final error write to the
// format-specific writer function.
//
// The returned error always carries the upstream status (and message when
// available) so callers can log the failure; a *UpstreamFailoverError is
// returned instead when rate-limit tracking decides the account should be
// rotated to another one.
func (s *OpenAIGatewayService) handleCompatErrorResponse(
	resp *http.Response,
	c *gin.Context,
	account *Account,
	writeError compatErrorWriter,
) (*OpenAIForwardResult, error) {
	// Read at most 2 MiB of the error body — enough for any structured error.
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
	if upstreamMsg == "" {
		upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
	}
	upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
	// Optionally capture a truncated copy of the raw body for ops logging.
	upstreamDetail := ""
	if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if maxBytes <= 0 {
			maxBytes = 2048
		}
		upstreamDetail = truncateString(string(body), maxBytes)
	}
	setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
	// Apply error passthrough rules
	if status, errType, errMsg, matched := applyErrorPassthroughRule(
		c, account.Platform, resp.StatusCode, body,
		http.StatusBadGateway, "api_error", "Upstream request failed",
	); matched {
		writeError(c, status, errType, errMsg)
		if upstreamMsg == "" {
			upstreamMsg = errMsg
		}
		if upstreamMsg == "" {
			return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
		}
		return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
	}
	// Check custom error codes — if the account does not handle this status,
	// return a generic error without exposing upstream details.
	if !account.ShouldHandleErrorCode(resp.StatusCode) {
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: resp.StatusCode,
			UpstreamRequestID:  resp.Header.Get("x-request-id"),
			Kind:               "http_error",
			Message:            upstreamMsg,
			Detail:             upstreamDetail,
		})
		writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
		if upstreamMsg == "" {
			return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
		}
		return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
	}
	// Track rate limits and decide whether to trigger secondary failover.
	shouldDisable := false
	if s.rateLimitService != nil {
		shouldDisable = s.rateLimitService.HandleUpstreamError(
			c.Request.Context(), account, resp.StatusCode, resp.Header, body,
		)
	}
	kind := "http_error"
	if shouldDisable {
		kind = "failover"
	}
	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
		Platform:           account.Platform,
		AccountID:          account.ID,
		AccountName:        account.Name,
		UpstreamStatusCode: resp.StatusCode,
		UpstreamRequestID:  resp.Header.Get("x-request-id"),
		Kind:               kind,
		Message:            upstreamMsg,
		Detail:             upstreamDetail,
	})
	// Secondary failover: no response is written here; the caller's failover
	// loop retries on another account (or the same one in pool mode when the
	// status is retryable).
	if shouldDisable {
		return nil, &UpstreamFailoverError{
			StatusCode:             resp.StatusCode,
			ResponseBody:           body,
			RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
		}
	}
	// Map status code to error type and write response
	errType := "api_error"
	switch {
	case resp.StatusCode == 400:
		errType = "invalid_request_error"
	case resp.StatusCode == 404:
		errType = "not_found_error"
	case resp.StatusCode == 429:
		errType = "rate_limit_error"
	case resp.StatusCode >= 500:
		errType = "api_error"
	}
	writeError(c, resp.StatusCode, errType, upstreamMsg)
	return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// openaiStreamingResult streaming response result
type openaiStreamingResult struct {
usage *OpenAIUsage