// Mirror of https://gitee.com/wanwujie/sub2api
// Synced 2026-05-05 05:30:44 +08:00.
// 261 lines, 9.7 KiB, Go.

//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
base string
|
||
want string
|
||
}{
|
||
// 已是 /chat/completions:原样返回
|
||
{"already chat/completions", "https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions"},
|
||
// 以 /v1 结尾:追加 /chat/completions
|
||
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/chat/completions"},
|
||
// 其他情况:追加 /v1/chat/completions
|
||
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/chat/completions"},
|
||
{"domain with trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/chat/completions"},
|
||
// 第三方上游常见形式
|
||
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
|
||
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
|
||
// 带空白字符
|
||
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
got := buildOpenAIChatCompletionsURL(tt.base)
|
||
require.Equal(t, tt.want, got)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestBuildOpenAIResponsesURL_ProbeURL 锁定 probe/测试端点使用的 URL 构建逻辑,
|
||
// 确保 buildOpenAIResponsesURL 对标准 OpenAI base_url 格式均拼出 `/v1/responses`。
|
||
func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
base string
|
||
want string
|
||
}{
|
||
{"bare domain", "https://api.openai.com", "https://api.openai.com/v1/responses"},
|
||
{"domain trailing slash", "https://api.openai.com/", "https://api.openai.com/v1/responses"},
|
||
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
|
||
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
|
||
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
|
||
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
got := buildOpenAIResponsesURL(tt.base)
|
||
require.Equal(t, tt.want, got)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||
c.Request.Header.Set("Content-Type", "application/json")
|
||
|
||
upstreamBody := strings.Join([]string{
|
||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||
"",
|
||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`,
|
||
"",
|
||
"data: [DONE]",
|
||
"",
|
||
}, "\n")
|
||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}},
|
||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||
}}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
cfg: rawChatCompletionsTestConfig(),
|
||
httpUpstream: upstream,
|
||
}
|
||
account := rawChatCompletionsTestAccount()
|
||
|
||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, 9, result.Usage.InputTokens)
|
||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||
require.NotNil(t, upstream.lastReq)
|
||
require.NoError(t, upstream.lastReq.Context().Err())
|
||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||
require.Contains(t, rec.Body.String(), `"usage"`)
|
||
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||
}
|
||
|
||
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||
c.Request.Header.Set("Content-Type", "application/json")
|
||
|
||
upstreamBody := strings.Join([]string{
|
||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||
"",
|
||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`,
|
||
"",
|
||
"data: [DONE]",
|
||
"",
|
||
}, "\n")
|
||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}},
|
||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||
}}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
cfg: rawChatCompletionsTestConfig(),
|
||
httpUpstream: upstream,
|
||
}
|
||
account := rawChatCompletionsTestAccount()
|
||
|
||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.Equal(t, 17, result.Usage.InputTokens)
|
||
require.Equal(t, 8, result.Usage.OutputTokens)
|
||
require.Equal(t, 6, result.Usage.CacheReadInputTokens)
|
||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||
}
|
||
|
||
func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
reqCtx, cancel := context.WithCancel(context.Background())
|
||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||
c.Request.Header.Set("Content-Type", "application/json")
|
||
cancel()
|
||
|
||
upstreamBody := strings.Join([]string{
|
||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
|
||
"",
|
||
"data: [DONE]",
|
||
"",
|
||
}, "\n")
|
||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}},
|
||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||
}}
|
||
|
||
svc := &OpenAIGatewayService{
|
||
cfg: rawChatCompletionsTestConfig(),
|
||
httpUpstream: upstream,
|
||
}
|
||
account := rawChatCompletionsTestAccount()
|
||
|
||
result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "")
|
||
require.NoError(t, err)
|
||
require.NotNil(t, result)
|
||
require.NotNil(t, upstream.lastReq)
|
||
require.NoError(t, upstream.lastReq.Context().Err())
|
||
}
|
||
|
||
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`))
|
||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(``))
|
||
}
|
||
|
||
func TestEnsureOpenAIChatStreamUsage(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
body, err := ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4"}`))
|
||
require.NoError(t, err)
|
||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||
|
||
body, err = ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`))
|
||
require.NoError(t, err)
|
||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||
}
|
||
|
||
func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||
resp := &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||
Body: io.NopCloser(strings.NewReader("toolong")),
|
||
}
|
||
svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()}
|
||
svc.cfg.Gateway.UpstreamResponseReadMaxBytes = 3
|
||
|
||
result, err := svc.bufferRawChatCompletions(c, resp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now())
|
||
require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge)
|
||
require.Nil(t, result)
|
||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||
}
|
||
|
||
func rawChatCompletionsTestConfig() *config.Config {
|
||
return &config.Config{
|
||
Security: config.SecurityConfig{
|
||
URLAllowlist: config.URLAllowlistConfig{
|
||
Enabled: false,
|
||
AllowInsecureHTTP: true,
|
||
},
|
||
},
|
||
}
|
||
}
|
||
|
||
func rawChatCompletionsTestAccount() *Account {
|
||
return &Account{
|
||
ID: 101,
|
||
Name: "raw-openai-apikey",
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
Concurrency: 1,
|
||
Credentials: map[string]any{
|
||
"api_key": "sk-test",
|
||
"base_url": "http://upstream.example",
|
||
},
|
||
}
|
||
}
|