Files
sub2api/backend/internal/service/openai_gateway_chat_completions_raw_test.go

261 lines
9.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
base string
want string
}{
// 已是 /chat/completions原样返回
{"already chat/completions", "https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions"},
// 以 /v1 结尾:追加 /chat/completions
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/chat/completions"},
// 其他情况:追加 /v1/chat/completions
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/chat/completions"},
{"domain with trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/chat/completions"},
// 第三方上游常见形式
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
// 带空白字符
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := buildOpenAIChatCompletionsURL(tt.base)
require.Equal(t, tt.want, got)
})
}
}
// TestBuildOpenAIResponsesURL_ProbeURL 锁定 probe/测试端点使用的 URL 构建逻辑,
// 确保 buildOpenAIResponsesURL 对标准 OpenAI base_url 格式均拼出 `/v1/responses`。
func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
base string
want string
}{
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/responses"},
{"domain trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/responses"},
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := buildOpenAIResponsesURL(tt.base)
require.Equal(t, tt.want, got)
})
}
}
// TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream
// sends a streaming chat-completions request that carries no stream_options
// through forwardAsRawChatCompletions against a canned SSE upstream reply.
//
// It asserts that:
//   - the usage chunk in the upstream reply is reflected in result.Usage
//     (input, output and cache-read token counts);
//   - the recorded upstream request body has stream_options.include_usage
//     set to true even though the client body did not ask for it;
//   - the recorded upstream request context is not cancelled;
//   - the bytes written to the client contain both the "usage" field and the
//     terminating "data: [DONE]" line.
func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) {
	gin.SetMode(gin.TestMode)

	body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	// Canned SSE stream: one content chunk, one usage-only chunk, then [DONE].
	upstreamBody := strings.Join([]string{
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
		"",
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")

	// Populate the header via Set so keys are stored in canonical form
	// ("X-Request-Id"). http.Header.Get canonicalizes its argument, so a map
	// literal keyed by the lowercase "x-request-id" would never be found by
	// Get; a real net/http response always carries canonical keys.
	// NOTE(review): whether the code under test reads this header is not
	// visible from this file.
	respHeader := http.Header{}
	respHeader.Set("Content-Type", "text/event-stream")
	respHeader.Set("x-request-id", "rid_raw_usage")

	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     respHeader,
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}
	svc := &OpenAIGatewayService{
		cfg:          rawChatCompletionsTestConfig(),
		httpUpstream: upstream,
	}
	account := rawChatCompletionsTestAccount()

	result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
	require.NoError(t, err)
	require.NotNil(t, result)

	// Usage values come from the usage-only chunk in upstreamBody.
	require.Equal(t, 9, result.Usage.InputTokens)
	require.Equal(t, 4, result.Usage.OutputTokens)
	require.Equal(t, 3, result.Usage.CacheReadInputTokens)

	require.NotNil(t, upstream.lastReq)
	require.NoError(t, upstream.lastReq.Context().Err())
	require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())

	require.Contains(t, rec.Body.String(), `"usage"`)
	require.Contains(t, rec.Body.String(), "data: [DONE]")
}
// TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage runs
// forwardAsRawChatCompletions with the gin response writer wrapped in
// openAIChatFailingWriter (failAfter: 0) and a canned SSE upstream reply.
//
// It asserts that the call still returns without error, that result.Usage
// carries the token counts from the upstream usage chunk, and that the
// recorded upstream request body has stream_options.include_usage set to true.
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	// NOTE(review): openAIChatFailingWriter is defined elsewhere; presumably it
	// makes client writes fail to simulate a disconnected client — confirm.
	c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	// Canned SSE stream: one content chunk, one usage-only chunk, then [DONE].
	upstreamBody := strings.Join([]string{
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
		"",
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")

	// Populate the header via Set so keys are stored in canonical form
	// ("X-Request-Id"). http.Header.Get canonicalizes its argument, so a map
	// literal keyed by the lowercase "x-request-id" would never be found by
	// Get; a real net/http response always carries canonical keys.
	// NOTE(review): whether the code under test reads this header is not
	// visible from this file.
	respHeader := http.Header{}
	respHeader.Set("Content-Type", "text/event-stream")
	respHeader.Set("x-request-id", "rid_raw_disconnect")

	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     respHeader,
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}
	svc := &OpenAIGatewayService{
		cfg:          rawChatCompletionsTestConfig(),
		httpUpstream: upstream,
	}
	account := rawChatCompletionsTestAccount()

	result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
	require.NoError(t, err)
	require.NotNil(t, result)

	// Usage values come from the usage-only chunk in upstreamBody.
	require.Equal(t, 17, result.Usage.InputTokens)
	require.Equal(t, 8, result.Usage.OutputTokens)
	require.Equal(t, 6, result.Usage.CacheReadInputTokens)

	require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
}
// TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel calls
// forwardAsRawChatCompletions with a context that has already been cancelled
// (the same context is attached to the inbound gin request).
//
// It asserts that the call succeeds and that the context of the recorded
// upstream request reports no error, i.e. the upstream request was not issued
// on the cancelled client context.
func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqCtx, cancel := context.WithCancel(context.Background())
	body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
	c.Request.Header.Set("Content-Type", "application/json")
	// Cancel before forwarding so the client context is already done when the
	// service builds its upstream request.
	cancel()

	// Canned SSE stream: a single usage-only chunk followed by [DONE].
	upstreamBody := strings.Join([]string{
		`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
		"",
		"data: [DONE]",
		"",
	}, "\n")

	// Populate the header via Set so keys are stored in canonical form
	// ("X-Request-Id"). http.Header.Get canonicalizes its argument, so a map
	// literal keyed by the lowercase "x-request-id" would never be found by
	// Get; a real net/http response always carries canonical keys.
	// NOTE(review): whether the code under test reads this header is not
	// visible from this file.
	respHeader := http.Header{}
	respHeader.Set("Content-Type", "text/event-stream")
	respHeader.Set("x-request-id", "rid_raw_ctx")

	upstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     respHeader,
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}}
	svc := &OpenAIGatewayService{
		cfg:          rawChatCompletionsTestConfig(),
		httpUpstream: upstream,
	}
	account := rawChatCompletionsTestAccount()

	result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.NotNil(t, upstream.lastReq)
	require.NoError(t, upstream.lastReq.Context().Err())
}
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
t.Parallel()
require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`))
require.False(t, isOpenAIChatUsageOnlyStreamChunk(``))
}
// TestEnsureOpenAIChatStreamUsage checks that ensureOpenAIChatStreamUsage
// returns a body whose stream_options.include_usage is true, both when the
// input has no stream_options and when it explicitly sets include_usage to
// false.
func TestEnsureOpenAIChatStreamUsage(t *testing.T) {
	t.Parallel()

	inputs := []string{
		// No stream_options present.
		`{"model":"gpt-5.4"}`,
		// include_usage explicitly disabled by the caller.
		`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`,
	}

	for _, in := range inputs {
		out, err := ensureOpenAIChatStreamUsage([]byte(in))
		require.NoError(t, err)

		includeUsage := gjson.GetBytes(out, "stream_options.include_usage")
		require.True(t, includeUsage.Bool())
	}
}
// TestBufferRawChatCompletions_RejectsOversizedResponse configures an upstream
// read limit of 3 bytes and hands bufferRawChatCompletions a 7-byte JSON
// response body. It expects ErrUpstreamResponseBodyTooLarge, a nil result, and
// a 502 Bad Gateway status recorded on the client response.
func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)

	// Limit is smaller than the body below so the read must overflow.
	cfg := rawChatCompletionsTestConfig()
	cfg.Gateway.UpstreamResponseReadMaxBytes = 3
	svc := &OpenAIGatewayService{cfg: cfg}

	jsonHeader := http.Header{"Content-Type": []string{"application/json"}}
	oversized := &http.Response{
		StatusCode: http.StatusOK,
		Header:     jsonHeader,
		Body:       io.NopCloser(strings.NewReader("toolong")),
	}

	result, err := svc.bufferRawChatCompletions(ginCtx, oversized, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now())

	require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge)
	require.Nil(t, result)
	require.Equal(t, http.StatusBadGateway, recorder.Code)
}
// rawChatCompletionsTestConfig returns a fresh *config.Config for the raw
// chat-completions tests. Only the URL allowlist settings are populated:
// the allowlist is disabled and insecure HTTP is allowed; every other field
// keeps its zero value.
func rawChatCompletionsTestConfig() *config.Config {
	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	return cfg
}
// rawChatCompletionsTestAccount returns a fresh OpenAI API-key *Account
// fixture (ID 101, concurrency 1) whose credentials hold a dummy API key and
// a plain-HTTP base URL.
func rawChatCompletionsTestAccount() *Account {
	creds := map[string]any{
		"api_key":  "sk-test",
		"base_url": "http://upstream.example",
	}

	acct := Account{
		ID:          101,
		Name:        "raw-openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: creds,
	}
	return &acct
}