Merge pull request #926 from 7976723/feat/chat-completions-compat-v2

feat: 添加 OpenAI Chat Completions 兼容端点(基于 #648,修复编译错误和运行时 panic)
This commit is contained in:
Wesley Liddick
2026-03-11 17:42:03 +08:00
committed by GitHub
6 changed files with 1725 additions and 9 deletions

View File

@@ -0,0 +1,546 @@
package handler
import (
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChatCompletions handles OpenAI Chat Completions API compatibility.
// POST /v1/chat/completions
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// Preserve original chat-completions request for upstream passthrough when needed.
c.Set(service.OpenAIChatCompletionsBodyKey, body)
var chatReq map[string]any
if err := json.Unmarshal(body, &chatReq); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
includeUsage := false
if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok {
if v, ok := streamOptions["include_usage"].(bool); ok {
includeUsage = v
}
}
c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage)
converted, err := service.ConvertChatCompletionsToResponses(chatReq)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
return
}
convertedBody, err := json.Marshal(converted)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
stream, _ := converted["stream"].(bool)
model, _ := converted["model"].(string)
originalWriter := c.Writer
writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model)
c.Writer = writer
c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody))
c.Request.ContentLength = int64(len(convertedBody))
h.Responses(c)
writer.Finalize()
c.Writer = originalWriter
}
type chatCompletionsResponseWriter struct {
gin.ResponseWriter
stream bool
includeUsage bool
buffer bytes.Buffer
streamBuf bytes.Buffer
state *chatCompletionStreamState
corrector *service.CodexToolCorrector
finalized bool
passthrough bool
}
type chatCompletionStreamState struct {
id string
model string
created int64
sentRole bool
sawToolCall bool
sawText bool
toolCallIndex map[string]int
usage map[string]any
}
func newChatCompletionsResponseWriter(w gin.ResponseWriter, stream bool, includeUsage bool, model string) *chatCompletionsResponseWriter {
return &chatCompletionsResponseWriter{
ResponseWriter: w,
stream: stream,
includeUsage: includeUsage,
state: &chatCompletionStreamState{
model: strings.TrimSpace(model),
toolCallIndex: make(map[string]int),
},
corrector: service.NewCodexToolCorrector(),
}
}
func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) {
if w.passthrough {
return w.ResponseWriter.Write(data)
}
if w.stream {
n, err := w.streamBuf.Write(data)
if err != nil {
return n, err
}
w.flushStreamBuffer()
return n, nil
}
if w.finalized {
return len(data), nil
}
return w.buffer.Write(data)
}
func (w *chatCompletionsResponseWriter) WriteString(s string) (int, error) {
return w.Write([]byte(s))
}
func (w *chatCompletionsResponseWriter) Finalize() {
if w.finalized {
return
}
w.finalized = true
if w.passthrough {
return
}
if w.stream {
return
}
body := w.buffer.Bytes()
if len(body) == 0 {
return
}
w.ResponseWriter.Header().Del("Content-Length")
converted, err := service.ConvertResponsesToChatCompletion(body)
if err != nil {
_, _ = w.ResponseWriter.Write(body)
return
}
corrected := converted
if correctedStr, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok {
corrected = []byte(correctedStr)
}
_, _ = w.ResponseWriter.Write(corrected)
}
func (w *chatCompletionsResponseWriter) SetPassthrough() {
w.passthrough = true
}
func (w *chatCompletionsResponseWriter) Status() int {
if w.ResponseWriter == nil {
return 0
}
return w.ResponseWriter.Status()
}
func (w *chatCompletionsResponseWriter) Written() bool {
if w.ResponseWriter == nil {
return false
}
return w.ResponseWriter.Written()
}
func (w *chatCompletionsResponseWriter) flushStreamBuffer() {
for {
buf := w.streamBuf.Bytes()
idx := bytes.IndexByte(buf, '\n')
if idx == -1 {
return
}
lineBytes := w.streamBuf.Next(idx + 1)
line := strings.TrimRight(string(lineBytes), "\r\n")
w.handleStreamLine(line)
}
}
func (w *chatCompletionsResponseWriter) handleStreamLine(line string) {
if line == "" {
return
}
if strings.HasPrefix(line, ":") {
_, _ = w.ResponseWriter.Write([]byte(line + "\n\n"))
return
}
if !strings.HasPrefix(line, "data:") {
return
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
for _, chunk := range w.convertResponseDataToChatChunks(data) {
if chunk == "" {
continue
}
if chunk == "[DONE]" {
_, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
continue
}
_, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n"))
}
}
func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string {
if data == "" {
return nil
}
if data == "[DONE]" {
return []string{"[DONE]"}
}
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
return []string{data}
}
if _, ok := payload["error"]; ok {
return []string{data}
}
eventType := strings.TrimSpace(getString(payload["type"]))
if eventType == "" {
return []string{data}
}
w.state.applyMetadata(payload)
switch eventType {
case "response.created":
return nil
case "response.output_text.delta":
delta := getString(payload["delta"])
if delta == "" {
return nil
}
w.state.sawText = true
return []string{w.buildTextDeltaChunk(delta)}
case "response.output_text.done":
if w.state.sawText {
return nil
}
text := getString(payload["text"])
if text == "" {
return nil
}
w.state.sawText = true
return []string{w.buildTextDeltaChunk(text)}
case "response.output_item.added", "response.output_item.delta":
if item, ok := payload["item"].(map[string]any); ok {
if callID, name, args, ok := extractToolCallFromItem(item); ok {
w.state.sawToolCall = true
return []string{w.buildToolCallChunk(callID, name, args)}
}
}
case "response.completed", "response.done":
if responseObj, ok := payload["response"].(map[string]any); ok {
w.state.applyResponseUsage(responseObj)
}
return []string{w.buildFinalChunk()}
}
if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") {
callID := strings.TrimSpace(getString(payload["call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(payload["tool_call_id"]))
}
if callID == "" {
callID = strings.TrimSpace(getString(payload["id"]))
}
args := getString(payload["delta"])
name := strings.TrimSpace(getString(payload["name"]))
if callID != "" && (args != "" || name != "") {
w.state.sawToolCall = true
return []string{w.buildToolCallChunk(callID, name, args)}
}
}
return nil
}
func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string {
w.state.ensureDefaults()
payload := map[string]any{
"content": delta,
}
if !w.state.sentRole {
payload["role"] = "assistant"
w.state.sentRole = true
}
return w.buildChunk(payload, nil, nil)
}
func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string {
w.state.ensureDefaults()
index := w.state.toolCallIndexFor(callID)
function := map[string]any{}
if name != "" {
function["name"] = name
}
if args != "" {
function["arguments"] = args
}
toolCall := map[string]any{
"index": index,
"id": callID,
"type": "function",
"function": function,
}
delta := map[string]any{
"tool_calls": []any{toolCall},
}
if !w.state.sentRole {
delta["role"] = "assistant"
w.state.sentRole = true
}
return w.buildChunk(delta, nil, nil)
}
func (w *chatCompletionsResponseWriter) buildFinalChunk() string {
w.state.ensureDefaults()
finishReason := "stop"
if w.state.sawToolCall {
finishReason = "tool_calls"
}
usage := map[string]any(nil)
if w.includeUsage && w.state.usage != nil {
usage = w.state.usage
}
return w.buildChunk(map[string]any{}, finishReason, usage)
}
func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string {
w.state.ensureDefaults()
chunk := map[string]any{
"id": w.state.id,
"object": "chat.completion.chunk",
"created": w.state.created,
"model": w.state.model,
"choices": []any{
map[string]any{
"index": 0,
"delta": delta,
"finish_reason": finishReason,
},
},
}
if usage != nil {
chunk["usage"] = usage
}
data, _ := json.Marshal(chunk)
if corrected, ok := w.corrector.CorrectToolCallsInSSEData(string(data)); ok {
return corrected
}
return string(data)
}
func (s *chatCompletionStreamState) ensureDefaults() {
if s.id == "" {
s.id = "chatcmpl-" + randomHexUnsafe(12)
}
if s.model == "" {
s.model = "unknown"
}
if s.created == 0 {
s.created = time.Now().Unix()
}
}
func (s *chatCompletionStreamState) toolCallIndexFor(callID string) int {
if idx, ok := s.toolCallIndex[callID]; ok {
return idx
}
idx := len(s.toolCallIndex)
s.toolCallIndex[callID] = idx
return idx
}
func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) {
if responseObj, ok := payload["response"].(map[string]any); ok {
s.applyResponseMetadata(responseObj)
}
if s.id == "" {
if id := strings.TrimSpace(getString(payload["response_id"])); id != "" {
s.id = id
} else if id := strings.TrimSpace(getString(payload["id"])); id != "" {
s.id = id
}
}
if s.model == "" {
if model := strings.TrimSpace(getString(payload["model"])); model != "" {
s.model = model
}
}
if s.created == 0 {
if created := getInt64(payload["created_at"]); created != 0 {
s.created = created
} else if created := getInt64(payload["created"]); created != 0 {
s.created = created
}
}
}
func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) {
if s.id == "" {
if id := strings.TrimSpace(getString(responseObj["id"])); id != "" {
s.id = id
}
}
if s.model == "" {
if model := strings.TrimSpace(getString(responseObj["model"])); model != "" {
s.model = model
}
}
if s.created == 0 {
if created := getInt64(responseObj["created_at"]); created != 0 {
s.created = created
}
}
}
func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) {
usage, ok := responseObj["usage"].(map[string]any)
if !ok {
return
}
promptTokens := int(getNumber(usage["input_tokens"]))
completionTokens := int(getNumber(usage["output_tokens"]))
if promptTokens == 0 && completionTokens == 0 {
return
}
s.usage = map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": completionTokens,
"total_tokens": promptTokens + completionTokens,
}
}
func extractToolCallFromItem(item map[string]any) (string, string, string, bool) {
itemType := strings.TrimSpace(getString(item["type"]))
if itemType != "tool_call" && itemType != "function_call" {
return "", "", "", false
}
callID := strings.TrimSpace(getString(item["call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(item["id"]))
}
name := strings.TrimSpace(getString(item["name"]))
args := getString(item["arguments"])
if fn, ok := item["function"].(map[string]any); ok {
if name == "" {
name = strings.TrimSpace(getString(fn["name"]))
}
if args == "" {
args = getString(fn["arguments"])
}
}
if callID == "" && name == "" && args == "" {
return "", "", "", false
}
if callID == "" {
callID = "call_" + randomHexUnsafe(6)
}
return callID, name, args, true
}
func getString(value any) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case json.Number:
return v.String()
default:
return ""
}
}
func getNumber(value any) float64 {
switch v := value.(type) {
case float64:
return v
case float32:
return float64(v)
case int:
return float64(v)
case int64:
return float64(v)
case json.Number:
f, _ := v.Float64()
return f
default:
return 0
}
}
func getInt64(value any) int64 {
switch v := value.(type) {
case int64:
return v
case int:
return int64(v)
case float64:
return int64(v)
case json.Number:
i, _ := v.Int64()
return i
default:
return 0
}
}
func randomHexUnsafe(byteLength int) string {
if byteLength <= 0 {
byteLength = 8
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "000000"
}
return hex.EncodeToString(buf)
}