refactor: 重构 Chat Completions 端点,采用类型安全的 Responses API 转换

将 /v1/chat/completions 端点从 ResponseWriter 劫持模式重构为独立的
类型安全转换路径,与 Anthropic Messages 端点架构对齐:

- 在 apicompat 包新增 Chat Completions 完整类型定义和双向转换器
- 新增 ForwardAsChatCompletions service 方法,走 Responses API 上游
- Handler 改为独立的账号选择/failover 循环,不再劫持 Responses handler
- 提取 handleCompatErrorResponse 为 Chat Completions 和 Messages 共用
- 删除旧的 forwardChatCompletions 直传路径及相关死代码
This commit is contained in:
shaw
2026-03-11 22:10:22 +08:00
parent 8dd38f4775
commit 9d81467937
11 changed files with 2420 additions and 1717 deletions

View File

@@ -1,23 +1,53 @@
package handler
import (
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"io"
"context"
"errors"
"net/http"
"strings"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API compatibility.
// ChatCompletions handles OpenAI Chat Completions API requests.
// POST /v1/chat/completions
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -31,516 +61,230 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
// Preserve original chat-completions request for upstream passthrough when needed.
c.Set(service.OpenAIChatCompletionsBodyKey, body)
var chatReq map[string]any
if err := json.Unmarshal(body, &chatReq); err != nil {
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
includeUsage := false
if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok {
if v, ok := streamOptions["include_usage"].(bool); ok {
includeUsage = v
}
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage)
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
converted, err := service.ConvertChatCompletionsToResponses(chatReq)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
convertedBody, err := json.Marshal(converted)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
stream, _ := converted["stream"].(bool)
model, _ := converted["model"].(string)
originalWriter := c.Writer
writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model)
c.Writer = writer
c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody))
c.Request.ContentLength = int64(len(convertedBody))
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
h.Responses(c)
writer.Finalize()
c.Writer = originalWriter
}
// chatCompletionsResponseWriter wraps gin.ResponseWriter so upstream
// Responses API output can be rewritten into Chat Completions shape before
// it reaches the client. Non-streaming bodies are buffered whole and
// converted in Finalize; streaming bodies are re-framed line by line.
type chatCompletionsResponseWriter struct {
	gin.ResponseWriter
	stream       bool                        // client requested an SSE streaming response
	includeUsage bool                        // stream_options.include_usage from the original request
	buffer       bytes.Buffer                // whole-body buffer for non-streaming responses
	streamBuf    bytes.Buffer                // holds partial SSE lines until a '\n' arrives
	state        *chatCompletionStreamState  // accumulated chunk metadata (id/model/created/usage)
	corrector    *service.CodexToolCorrector // rewrites malformed tool-call payloads in emitted JSON
	finalized    bool                        // Finalize already ran; later non-stream writes are dropped
	passthrough  bool                        // bypass conversion and forward bytes verbatim
}
// chatCompletionStreamState accumulates metadata observed while translating
// a Responses SSE stream into chat.completion.chunk events.
type chatCompletionStreamState struct {
	id            string         // chat completion id (taken from the response, or generated)
	model         string         // model name echoed in every chunk
	created       int64          // unix creation timestamp stamped on chunks
	sentRole      bool           // initial role:"assistant" delta already emitted
	sawToolCall   bool           // a tool call was forwarded; finish_reason becomes "tool_calls"
	sawToolCall   bool           // NOTE: see buildFinalChunk
	sawText       bool           // text deltas were forwarded (suppresses duplicate output_text.done)
	toolCallIndex map[string]int // stable per-call-id index assignment for tool_calls deltas
	usage         map[string]any // usage payload attached to the final chunk when includeUsage is set
}
// newChatCompletionsResponseWriter builds a converting writer for a single
// request. model seeds the stream state so emitted chunks carry the requested
// model name even before upstream metadata is seen.
func newChatCompletionsResponseWriter(w gin.ResponseWriter, stream bool, includeUsage bool, model string) *chatCompletionsResponseWriter {
	state := &chatCompletionStreamState{
		model:         strings.TrimSpace(model),
		toolCallIndex: map[string]int{},
	}
	return &chatCompletionsResponseWriter{
		ResponseWriter: w,
		stream:         stream,
		includeUsage:   includeUsage,
		state:          state,
		corrector:      service.NewCodexToolCorrector(),
	}
}
// Write routes response bytes according to the writer's mode: verbatim in
// passthrough, line-buffered conversion when streaming, and whole-body
// buffering otherwise (converted later by Finalize).
func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) {
	switch {
	case w.passthrough:
		return w.ResponseWriter.Write(data)
	case w.stream:
		n, err := w.streamBuf.Write(data)
		if err == nil {
			w.flushStreamBuffer()
		}
		return n, err
	case w.finalized:
		// Late writes after Finalize are swallowed but reported as accepted.
		return len(data), nil
	default:
		return w.buffer.Write(data)
	}
}
// WriteString funnels string writes through Write so conversion applies.
func (w *chatCompletionsResponseWriter) WriteString(s string) (int, error) {
	data := []byte(s)
	return w.Write(data)
}
// Finalize converts the buffered non-streaming body from Responses format to
// a Chat Completion object and writes it downstream. It is idempotent and a
// no-op for passthrough, streaming, or empty responses. On conversion failure
// the raw upstream body is forwarded unchanged as a fallback.
func (w *chatCompletionsResponseWriter) Finalize() {
	if w.finalized {
		return
	}
	w.finalized = true
	if w.passthrough || w.stream {
		return
	}
	body := w.buffer.Bytes()
	if len(body) == 0 {
		return
	}
	// The converted payload differs in size from the upstream body.
	w.ResponseWriter.Header().Del("Content-Length")
	converted, err := service.ConvertResponsesToChatCompletion(body)
	if err != nil {
		_, _ = w.ResponseWriter.Write(body)
		return
	}
	out := converted
	if fixed, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok {
		out = []byte(fixed)
	}
	_, _ = w.ResponseWriter.Write(out)
}
// SetPassthrough disables all conversion; subsequent writes are forwarded
// to the underlying writer untouched.
func (w *chatCompletionsResponseWriter) SetPassthrough() {
	w.passthrough = true
}
// Status reports the status code recorded by the wrapped writer, or 0 when
// no underlying writer is attached.
func (w *chatCompletionsResponseWriter) Status() int {
	if w.ResponseWriter != nil {
		return w.ResponseWriter.Status()
	}
	return 0
}
// Written reports whether the wrapped writer has already written a response;
// false when no underlying writer is attached.
func (w *chatCompletionsResponseWriter) Written() bool {
	if w.ResponseWriter != nil {
		return w.ResponseWriter.Written()
	}
	return false
}
// flushStreamBuffer drains every complete ('\n'-terminated) line from the
// stream buffer and converts it; a trailing partial line stays buffered
// until more bytes arrive.
func (w *chatCompletionsResponseWriter) flushStreamBuffer() {
	for {
		pending := w.streamBuf.Bytes()
		nl := bytes.IndexByte(pending, '\n')
		if nl < 0 {
			return
		}
		raw := w.streamBuf.Next(nl + 1)
		w.handleStreamLine(strings.TrimRight(string(raw), "\r\n"))
	}
}
// handleStreamLine converts one SSE line from the Responses stream into zero
// or more chat.completion.chunk SSE events. Comment lines (":") are forwarded
// as keep-alives; empty and non-data lines are dropped.
func (w *chatCompletionsResponseWriter) handleStreamLine(line string) {
	switch {
	case line == "":
		return
	case strings.HasPrefix(line, ":"):
		// SSE comment / keep-alive: forward unchanged.
		_, _ = w.ResponseWriter.Write([]byte(line + "\n\n"))
		return
	case !strings.HasPrefix(line, "data:"):
		return
	}
	payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	for _, chunk := range w.convertResponseDataToChatChunks(payload) {
		switch chunk {
		case "":
			// Nothing to emit for this event.
		case "[DONE]":
			_, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
		default:
			_, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n"))
		}
	}
}
func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string {
if data == "" {
return nil
}
if data == "[DONE]" {
return []string{"[DONE]"}
}
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
return []string{data}
}
if _, ok := payload["error"]; ok {
return []string{data}
}
eventType := strings.TrimSpace(getString(payload["type"]))
if eventType == "" {
return []string{data}
}
w.state.applyMetadata(payload)
switch eventType {
case "response.created":
return nil
case "response.output_text.delta":
delta := getString(payload["delta"])
if delta == "" {
return nil
}
w.state.sawText = true
return []string{w.buildTextDeltaChunk(delta)}
case "response.output_text.done":
if w.state.sawText {
return nil
}
text := getString(payload["text"])
if text == "" {
return nil
}
w.state.sawText = true
return []string{w.buildTextDeltaChunk(text)}
case "response.output_item.added", "response.output_item.delta":
if item, ok := payload["item"].(map[string]any); ok {
if callID, name, args, ok := extractToolCallFromItem(item); ok {
w.state.sawToolCall = true
return []string{w.buildToolCallChunk(callID, name, args)}
c.Set("openai_chat_completions_fallback_model", "")
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
reqLog.Info("openai_chat_completions.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)
}
}
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
} else {
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
}
return
}
}
case "response.completed", "response.done":
if responseObj, ok := payload["response"].(map[string]any); ok {
w.state.applyResponseUsage(responseObj)
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
return []string{w.buildFinalChunk()}
}
account := selection.Account
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
_ = scheduleDecision
setOpsSelectedAccount(c, account.ID, account.Platform)
if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") {
callID := strings.TrimSpace(getString(payload["call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(payload["tool_call_id"]))
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
if callID == "" {
callID = strings.TrimSpace(getString(payload["id"]))
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
args := getString(payload["delta"])
name := strings.TrimSpace(getString(payload["name"]))
if callID != "" && (args != "" || name != "") {
w.state.sawToolCall = true
return []string{w.buildToolCallChunk(callID, name, args)}
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
return nil
}
// buildTextDeltaChunk emits a content delta chunk, attaching role:"assistant"
// on the first delta of the stream.
func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string {
	w.state.ensureDefaults()
	d := map[string]any{"content": delta}
	if !w.state.sentRole {
		w.state.sentRole = true
		d["role"] = "assistant"
	}
	return w.buildChunk(d, nil, nil)
}
// buildToolCallChunk emits a tool_calls delta chunk for one tool call. The
// call keeps a stable index per call id; name and arguments are included
// only when non-empty. The first chunk of the stream also carries the role.
func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string {
	w.state.ensureDefaults()
	fn := make(map[string]any)
	if name != "" {
		fn["name"] = name
	}
	if args != "" {
		fn["arguments"] = args
	}
	call := map[string]any{
		"index":    w.state.toolCallIndexFor(callID),
		"id":       callID,
		"type":     "function",
		"function": fn,
	}
	d := map[string]any{"tool_calls": []any{call}}
	if !w.state.sentRole {
		w.state.sentRole = true
		d["role"] = "assistant"
	}
	return w.buildChunk(d, nil, nil)
}
// buildFinalChunk emits the terminating chunk with the finish_reason
// ("tool_calls" if any tool call was seen, otherwise "stop") and, when the
// client opted in via include_usage, the accumulated usage payload.
func (w *chatCompletionsResponseWriter) buildFinalChunk() string {
	w.state.ensureDefaults()
	reason := "stop"
	if w.state.sawToolCall {
		reason = "tool_calls"
	}
	var usage map[string]any
	if w.includeUsage && w.state.usage != nil {
		usage = w.state.usage
	}
	return w.buildChunk(map[string]any{}, reason, usage)
}
// buildChunk serializes a single-choice chat.completion.chunk object and
// runs the tool-call corrector over the encoded JSON before returning it.
func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string {
	w.state.ensureDefaults()
	choice := map[string]any{
		"index":         0,
		"delta":         delta,
		"finish_reason": finishReason,
	}
	chunk := map[string]any{
		"id":      w.state.id,
		"object":  "chat.completion.chunk",
		"created": w.state.created,
		"model":   w.state.model,
		"choices": []any{choice},
	}
	if usage != nil {
		chunk["usage"] = usage
	}
	encoded, _ := json.Marshal(chunk)
	if fixed, ok := w.corrector.CorrectToolCallsInSSEData(string(encoded)); ok {
		return fixed
	}
	return string(encoded)
}
// ensureDefaults backfills id, model, and created so that a chunk can always
// be emitted even before any upstream metadata has been observed.
func (s *chatCompletionStreamState) ensureDefaults() {
	if s.id == "" {
		// Match the OpenAI id style: "chatcmpl-" + random hex suffix.
		s.id = "chatcmpl-" + randomHexUnsafe(12)
	}
	if s.model == "" {
		s.model = "unknown"
	}
	if s.created == 0 {
		s.created = time.Now().Unix()
	}
}
// toolCallIndexFor returns a stable zero-based index for a tool call id,
// assigning the next free index the first time an id is seen.
func (s *chatCompletionStreamState) toolCallIndexFor(callID string) int {
	idx, seen := s.toolCallIndex[callID]
	if !seen {
		idx = len(s.toolCallIndex)
		s.toolCallIndex[callID] = idx
	}
	return idx
}
func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) {
if responseObj, ok := payload["response"].(map[string]any); ok {
s.applyResponseMetadata(responseObj)
}
if s.id == "" {
if id := strings.TrimSpace(getString(payload["response_id"])); id != "" {
s.id = id
} else if id := strings.TrimSpace(getString(payload["id"])); id != "" {
s.id = id
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}
if s.model == "" {
if model := strings.TrimSpace(getString(payload["model"])); model != "" {
s.model = model
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
}
if s.created == 0 {
if created := getInt64(payload["created_at"]); created != 0 {
s.created = created
} else if created := getInt64(payload["created"]); created != 0 {
s.created = created
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Warn("openai_chat_completions.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
}
}
// applyResponseMetadata fills id, model, and created from a Responses API
// response object, never overwriting values captured earlier in the stream.
func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) {
	if s.id == "" {
		if v := strings.TrimSpace(getString(responseObj["id"])); v != "" {
			s.id = v
		}
	}
	if s.model == "" {
		if v := strings.TrimSpace(getString(responseObj["model"])); v != "" {
			s.model = v
		}
	}
	if s.created == 0 {
		if v := getInt64(responseObj["created_at"]); v != 0 {
			s.created = v
		}
	}
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) {
usage, ok := responseObj["usage"].(map[string]any)
if !ok {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai_chat_completions.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
promptTokens := int(getNumber(usage["input_tokens"]))
completionTokens := int(getNumber(usage["output_tokens"]))
if promptTokens == 0 && completionTokens == 0 {
return
}
s.usage = map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": completionTokens,
"total_tokens": promptTokens + completionTokens,
}
}
// extractToolCallFromItem pulls (callID, name, arguments) out of a Responses
// output item of type "tool_call" or "function_call". The boolean is false
// when the item is not a tool call or carries no usable fields. Fields may
// live at the top level or nested under "function"; a missing call id is
// replaced with a random "call_…" identifier.
func extractToolCallFromItem(item map[string]any) (string, string, string, bool) {
	switch strings.TrimSpace(getString(item["type"])) {
	case "tool_call", "function_call":
		// recognized tool-call item types
	default:
		return "", "", "", false
	}
	callID := strings.TrimSpace(getString(item["call_id"]))
	if callID == "" {
		callID = strings.TrimSpace(getString(item["id"]))
	}
	name := strings.TrimSpace(getString(item["name"]))
	args := getString(item["arguments"])
	if fn, ok := item["function"].(map[string]any); ok {
		if name == "" {
			name = strings.TrimSpace(getString(fn["name"]))
		}
		if args == "" {
			args = getString(fn["arguments"])
		}
	}
	if callID == "" && name == "" && args == "" {
		return "", "", "", false
	}
	if callID == "" {
		callID = "call_" + randomHexUnsafe(6)
	}
	return callID, name, args, true
}
func getString(value any) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case json.Number:
return v.String()
default:
return ""
}
}
func getNumber(value any) float64 {
switch v := value.(type) {
case float64:
return v
case float32:
return float64(v)
case int:
return float64(v)
case int64:
return float64(v)
case json.Number:
f, _ := v.Float64()
return f
default:
return 0
}
}
func getInt64(value any) int64 {
switch v := value.(type) {
case int64:
return v
case int:
return int64(v)
case float64:
return int64(v)
case json.Number:
i, _ := v.Int64()
return i
default:
return 0
}
}
func randomHexUnsafe(byteLength int) string {
if byteLength <= 0 {
byteLength = 8
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "000000"
}
return hex.EncodeToString(buf)
}