mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-22 15:44:46 +08:00
286 lines
9.2 KiB
Go
286 lines
9.2 KiB
Go
|
|
//go:build unit
|
|||
|
|
|
|||
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"bytes"
|
|||
|
|
"context"
|
|||
|
|
"encoding/json"
|
|||
|
|
"io"
|
|||
|
|
"net/http"
|
|||
|
|
"net/http/httptest"
|
|||
|
|
"testing"
|
|||
|
|
|
|||
|
|
"github.com/gin-gonic/gin"
|
|||
|
|
"github.com/stretchr/testify/require"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// httpUpstreamCapture is a test double for the upstream HTTP client: it
// remembers the outgoing *http.Request so a test can assert on it, and
// answers every call with a canned response/error pair.
type httpUpstreamCapture struct {
	// capturedReq is the request most recently passed to Do or DoWithTLS.
	capturedReq *http.Request

	// resp and err are handed back verbatim to the caller.
	resp *http.Response
	err  error
}
|
|||
|
|
|
|||
|
|
func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
|||
|
|
s.capturedReq = req
|
|||
|
|
return s.resp, s.err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
|
|||
|
|
s.capturedReq = req
|
|||
|
|
return s.resp, s.err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func newUpstreamAccount() *Account {
|
|||
|
|
return &Account{
|
|||
|
|
ID: 100,
|
|||
|
|
Name: "upstream-test",
|
|||
|
|
Platform: PlatformAntigravity,
|
|||
|
|
Type: AccountTypeUpstream,
|
|||
|
|
Status: StatusActive,
|
|||
|
|
Concurrency: 1,
|
|||
|
|
Credentials: map[string]any{
|
|||
|
|
"base_url": "https://upstream.example.com",
|
|||
|
|
"api_key": "sk-upstream-secret",
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// makeUpstreamErrorResponse builds a canned 502 Bad Gateway response with a
// small JSON error body.
//
// The tests deliberately use an error status rather than a successful SSE
// stream: it bypasses the streaming response handlers and hits the error
// branch instead, which is sufficient for testing header passthrough on the
// outgoing request.
//
// A new response (with an unread body) is returned on every call.
func makeUpstreamErrorResponse() *http.Response {
	body := []byte(`{"error":{"message":"test error"}}`)

	// ContentLength is kept in sync with Body: its zero value would advertise
	// an empty body to any consumer that honours the field.
	return &http.Response{
		StatusCode:    http.StatusBadGateway,
		Header:        http.Header{"Content-Type": []string{"application/json"}},
		ContentLength: int64(len(body)),
		Body:          io.NopCloser(bytes.NewReader(body)),
	}
}
|
|||
|
|
|
|||
|
|
// --- ForwardUpstream tests ---
|
|||
|
|
|
|||
|
|
func TestForwardUpstream_PassthroughHeaders(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
c, _ := gin.CreateTestContext(w)
|
|||
|
|
|
|||
|
|
body, _ := json.Marshal(map[string]any{
|
|||
|
|
"model": "claude-sonnet-4-5",
|
|||
|
|
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
|||
|
|
"max_tokens": 1,
|
|||
|
|
"stream": false,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
req.Header.Set("anthropic-version", "2024-10-22")
|
|||
|
|
req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
|
|||
|
|
req.Header.Set("X-Custom-Header", "custom-value")
|
|||
|
|
c.Request = req
|
|||
|
|
|
|||
|
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|||
|
|
svc := &AntigravityGatewayService{
|
|||
|
|
tokenProvider: &AntigravityTokenProvider{},
|
|||
|
|
httpUpstream: stub,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
|||
|
|
|
|||
|
|
captured := stub.capturedReq
|
|||
|
|
require.NotNil(t, captured, "upstream request should have been made")
|
|||
|
|
|
|||
|
|
// 客户端 header 应被透传
|
|||
|
|
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
|||
|
|
require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version"))
|
|||
|
|
require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta"))
|
|||
|
|
require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header"))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
c, _ := gin.CreateTestContext(w)
|
|||
|
|
|
|||
|
|
body, _ := json.Marshal(map[string]any{
|
|||
|
|
"model": "claude-sonnet-4-5",
|
|||
|
|
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
|||
|
|
"max_tokens": 1,
|
|||
|
|
"stream": false,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
// 客户端发来的认证头应被覆盖
|
|||
|
|
req.Header.Set("Authorization", "Bearer client-token")
|
|||
|
|
req.Header.Set("x-api-key", "client-api-key")
|
|||
|
|
c.Request = req
|
|||
|
|
|
|||
|
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|||
|
|
svc := &AntigravityGatewayService{
|
|||
|
|
tokenProvider: &AntigravityTokenProvider{},
|
|||
|
|
httpUpstream: stub,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
|||
|
|
|
|||
|
|
captured := stub.capturedReq
|
|||
|
|
require.NotNil(t, captured)
|
|||
|
|
|
|||
|
|
// 认证头应使用上游账号的 api_key,而非客户端的
|
|||
|
|
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
|
|||
|
|
require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key"))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
c, _ := gin.CreateTestContext(w)
|
|||
|
|
|
|||
|
|
body, _ := json.Marshal(map[string]any{
|
|||
|
|
"model": "claude-sonnet-4-5",
|
|||
|
|
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
|||
|
|
"max_tokens": 1,
|
|||
|
|
"stream": false,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
req.Header.Set("Connection", "keep-alive")
|
|||
|
|
req.Header.Set("Keep-Alive", "timeout=5")
|
|||
|
|
req.Header.Set("Transfer-Encoding", "chunked")
|
|||
|
|
req.Header.Set("Upgrade", "websocket")
|
|||
|
|
req.Header.Set("Te", "trailers")
|
|||
|
|
c.Request = req
|
|||
|
|
|
|||
|
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|||
|
|
svc := &AntigravityGatewayService{
|
|||
|
|
tokenProvider: &AntigravityTokenProvider{},
|
|||
|
|
httpUpstream: stub,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
|||
|
|
|
|||
|
|
captured := stub.capturedReq
|
|||
|
|
require.NotNil(t, captured)
|
|||
|
|
|
|||
|
|
// hop-by-hop header 不应出现
|
|||
|
|
require.Empty(t, captured.Header.Get("Connection"))
|
|||
|
|
require.Empty(t, captured.Header.Get("Keep-Alive"))
|
|||
|
|
require.Empty(t, captured.Header.Get("Transfer-Encoding"))
|
|||
|
|
require.Empty(t, captured.Header.Get("Upgrade"))
|
|||
|
|
require.Empty(t, captured.Header.Get("Te"))
|
|||
|
|
|
|||
|
|
// 但普通 header 应保留
|
|||
|
|
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// --- ForwardUpstreamGemini tests ---
|
|||
|
|
|
|||
|
|
func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
c, _ := gin.CreateTestContext(w)
|
|||
|
|
|
|||
|
|
body, _ := json.Marshal(map[string]any{
|
|||
|
|
"contents": []map[string]any{
|
|||
|
|
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
req.Header.Set("X-Custom-Gemini", "gemini-value")
|
|||
|
|
req.Header.Set("X-Request-Id", "req-abc-123")
|
|||
|
|
c.Request = req
|
|||
|
|
|
|||
|
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|||
|
|
svc := &AntigravityGatewayService{
|
|||
|
|
tokenProvider: &AntigravityTokenProvider{},
|
|||
|
|
httpUpstream: stub,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
|||
|
|
|
|||
|
|
captured := stub.capturedReq
|
|||
|
|
require.NotNil(t, captured, "upstream request should have been made")
|
|||
|
|
|
|||
|
|
// 客户端 header 应被透传
|
|||
|
|
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
|||
|
|
require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini"))
|
|||
|
|
require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id"))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
c, _ := gin.CreateTestContext(w)
|
|||
|
|
|
|||
|
|
body, _ := json.Marshal(map[string]any{
|
|||
|
|
"contents": []map[string]any{
|
|||
|
|
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
req.Header.Set("Authorization", "Bearer client-gemini-token")
|
|||
|
|
c.Request = req
|
|||
|
|
|
|||
|
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|||
|
|
svc := &AntigravityGatewayService{
|
|||
|
|
tokenProvider: &AntigravityTokenProvider{},
|
|||
|
|
httpUpstream: stub,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
|||
|
|
|
|||
|
|
captured := stub.capturedReq
|
|||
|
|
require.NotNil(t, captured)
|
|||
|
|
|
|||
|
|
// 认证头应使用上游账号的 api_key
|
|||
|
|
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
c, _ := gin.CreateTestContext(w)
|
|||
|
|
|
|||
|
|
body, _ := json.Marshal(map[string]any{
|
|||
|
|
"contents": []map[string]any{
|
|||
|
|
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
|||
|
|
},
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
req.Header.Set("Connection", "keep-alive")
|
|||
|
|
req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz")
|
|||
|
|
req.Header.Set("Host", "evil.example.com")
|
|||
|
|
c.Request = req
|
|||
|
|
|
|||
|
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
|||
|
|
svc := &AntigravityGatewayService{
|
|||
|
|
tokenProvider: &AntigravityTokenProvider{},
|
|||
|
|
httpUpstream: stub,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
|||
|
|
|
|||
|
|
captured := stub.capturedReq
|
|||
|
|
require.NotNil(t, captured)
|
|||
|
|
|
|||
|
|
// hop-by-hop header 不应出现
|
|||
|
|
require.Empty(t, captured.Header.Get("Connection"))
|
|||
|
|
require.Empty(t, captured.Header.Get("Proxy-Authorization"))
|
|||
|
|
// Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传
|
|||
|
|
require.Empty(t, captured.Header.Values("Host"))
|
|||
|
|
|
|||
|
|
// 普通 header 应保留
|
|||
|
|
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
|||
|
|
}
|