Files
sub2api/backend/internal/service/antigravity_gateway_service_test.go

1493 lines
51 KiB
Go
Raw Permalink Normal View History

package service
import (
2026-01-23 22:24:46 +08:00
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
2026-01-23 22:24:46 +08:00
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
2026-01-23 22:24:46 +08:00
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
type antigravityFailingWriter struct {
gin.ResponseWriter
failAfter int // 允许成功写入的次数,之后所有写入返回错误
writes int
}
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
return &AntigravityGatewayService{
settingService: &SettingService{cfg: cfg},
}
}
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[
{"type":"thinking","thinking":"secret plan","signature":""},
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
]`),
},
{
Role: "user",
Content: json.RawMessage(`[
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
{"type":"redacted_thinking","data":"..."}
]`),
},
},
}
changed, err := stripSignatureSensitiveBlocksFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
require.Len(t, req.Messages, 2)
var blocks0 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks0))
require.Len(t, blocks0, 2)
require.Equal(t, "text", blocks0[0]["type"])
require.Equal(t, "secret plan", blocks0[0]["text"])
require.Equal(t, "text", blocks0[1]["type"])
var blocks1 []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[1].Content, &blocks1))
require.Len(t, blocks1, 1)
require.Equal(t, "text", blocks1[0]["type"])
require.NotEmpty(t, blocks1[0]["text"])
}
func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
Thinking: &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: 1024,
},
Messages: []antigravity.ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`),
},
},
}
changed, err := stripThinkingFromClaudeRequest(req)
require.NoError(t, err)
require.True(t, changed)
require.Nil(t, req.Thinking)
var blocks []map[string]any
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks))
require.Len(t, blocks, 2)
require.Equal(t, "text", blocks[0]["type"])
require.Equal(t, "secret plan", blocks[0]["text"])
require.Equal(t, "tool_use", blocks[1]["type"])
}
2026-01-23 22:24:46 +08:00
func TestIsPromptTooLongError(t *testing.T) {
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
}
type httpUpstreamStub struct {
resp *http.Response
err error
}
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return s.resp, s.err
}
2026-02-03 15:36:17 +08:00
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
return s.resp, s.err
}
2026-03-12 11:04:14 +08:00
type queuedHTTPUpstreamStub struct {
responses []*http.Response
errors []error
requestBodies [][]byte
callCount int
onCall func(*http.Request, *queuedHTTPUpstreamStub)
2026-03-12 11:04:14 +08:00
}
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
if req != nil && req.Body != nil {
body, _ := io.ReadAll(req.Body)
s.requestBodies = append(s.requestBodies, body)
req.Body = io.NopCloser(bytes.NewReader(body))
} else {
s.requestBodies = append(s.requestBodies, nil)
}
idx := s.callCount
s.callCount++
if s.onCall != nil {
s.onCall(req, s)
}
2026-03-12 11:04:14 +08:00
var resp *http.Response
if idx < len(s.responses) {
resp = s.responses[idx]
}
var err error
if idx < len(s.errors) {
err = s.errors[idx]
}
if resp == nil && err == nil {
return nil, errors.New("unexpected upstream call")
}
return resp, err
}
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, concurrency)
}
type antigravitySettingRepoStub struct{}
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
return "", ErrSettingNotFound
}
func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
2026-01-23 22:24:46 +08:00
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-6",
2026-01-23 22:24:46 +08:00
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"max_tokens": 1,
"stream": false,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
resp := &http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{"X-Request-Id": []string{"req-1"}},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
2026-01-23 22:24:46 +08:00
}
account := &Account{
ID: 1,
Name: "acc-1",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
}
result, err := svc.Forward(context.Background(), c, account, body, false)
2026-01-23 22:24:46 +08:00
require.Nil(t, result)
var promptErr *PromptTooLongError
require.ErrorAs(t, err, &promptErr)
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
require.Equal(t, "req-1", promptErr.RequestID)
require.NotEmpty(t, promptErr.Body)
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 1)
require.Equal(t, "prompt_too_long", events[0].Kind)
}
// TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover
// 验证:当账号存在模型限流且剩余时间 >= antigravityRateLimitThreshold 时,
// Forward 方法应返回 UpstreamFailoverError触发 Handler 切换账号
func TestAntigravityGatewayService_Forward_ModelRateLimitTriggersFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-6",
"messages": []map[string]any{
{"role": "user", "content": "hi"},
},
"max_tokens": 1,
"stream": false,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
// 不需要真正调用上游,因为预检查会直接返回切换信号
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 1,
Name: "acc-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
result, err := svc.Forward(context.Background(), c, account, body, false)
require.Nil(t, result, "Forward should not return result when model rate limited")
require.NotNil(t, err, "Forward should return error")
// 核心验证:错误应该是 UpstreamFailoverError而不是普通 502 错误
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
// 非粘性会话请求ForceCacheBilling 应为 false
require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
}
// TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover
// 验证ForwardGemini 方法同样能正确将 AntigravityAccountSwitchError 转换为 UpstreamFailoverError
func TestAntigravityGatewayService_ForwardGemini_ModelRateLimitTriggersFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
// 不需要真正调用上游,因为预检查会直接返回切换信号
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 2,
Name: "acc-gemini-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-2.5-flash": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
require.NotNil(t, err, "ForwardGemini should return error")
// 核心验证:错误应该是 UpstreamFailoverError而不是普通 502 错误
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
// 非粘性会话请求ForceCacheBilling 应为 false
require.False(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be false for non-sticky session")
}
// TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling
// 验证粘性会话切换时UpstreamFailoverError.ForceCacheBilling 应为 true
func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-opus-4-6",
"messages": []map[string]string{{"role": "user", "content": "hello"}},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 3,
Name: "acc-sticky-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-opus-4-6-thinking": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
// 传入 isStickySession = true
result, err := svc.Forward(context.Background(), c, account, body, true)
require.Nil(t, result, "Forward should not return result when model rate limited")
require.NotNil(t, err, "Forward should return error")
// 核心验证粘性会话切换时ForceCacheBilling 应为 true
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
svc := &AntigravityGatewayService{
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: nil, err: nil},
}
// 设置模型限流:剩余时间 30 秒(> antigravityRateLimitThreshold 7s
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account := &Account{
ID: 4,
Name: "acc-gemini-sticky-rate-limited",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gemini-2.5-flash": map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
},
}
// 传入 isStickySession = true
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, true)
require.Nil(t, result, "ForwardGemini should not return result when model rate limited")
require.NotNil(t, err, "ForwardGemini should return error")
// 核心验证粘性会话切换时ForceCacheBilling 应为 true
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "error should be UpstreamFailoverError to trigger account switch")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
// 验证Antigravity Claude 转发返回的计费模型使用映射后的模型
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"model": "claude-sonnet-4-5",
"messages": []map[string]any{
{"role": "user", "content": "hello"},
},
"max_tokens": 16,
"stream": true,
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request = req
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-bill-1"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
const mappedModel = "gemini-3-pro-high"
account := &Account{
ID: 5,
Name: "acc-forward-billing",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
"claude-sonnet-4-5": mappedModel,
},
},
}
result, err := svc.Forward(context.Background(), c, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
}
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
// 验证Antigravity Gemini 转发返回的计费模型使用映射后的模型
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
c.Request = req
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-bill-2"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: &httpUpstreamStub{resp: resp},
}
const mappedModel = "gemini-3-pro-high"
account := &Account{
ID: 6,
Name: "acc-gemini-billing",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
"gemini-2.5-flash": mappedModel,
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
}
2026-03-12 11:04:14 +08:00
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
c.Request = req
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req-sig-1"},
},
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req-sig-2"},
},
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
},
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
const originalModel = "gemini-3.1-pro-preview"
const mappedModel = "gemini-3.1-pro-high"
account := &Account{
ID: 7,
Name: "acc-gemini-signature",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
originalModel: mappedModel,
},
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model)
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
firstReq := string(upstream.requestBodies[0])
secondReq := string(upstream.requestBodies[1])
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.NotEmpty(t, events)
require.Equal(t, "signature_error", events[0].Kind)
}
func TestAntigravityGatewayService_ForwardGemini_SignatureRetryPropagatesFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
c.Request = req
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
const originalModel = "gemini-3.1-pro-preview"
const mappedModel = "gemini-3.1-pro-high"
account := &Account{
ID: 8,
Name: "acc-gemini-signature-failover",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"model_mapping": map[string]any{
originalModel: mappedModel,
},
},
}
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-Request-Id": []string{"req-sig-failover-1"},
},
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
},
},
onCall: func(_ *http.Request, stub *queuedHTTPUpstreamStub) {
if stub.callCount != 1 {
return
}
futureResetAt := time.Now().Add(30 * time.Second).Format(time.RFC3339)
account.Extra = map[string]any{
modelRateLimitsKey: map[string]any{
mappedModel: map[string]any{
"rate_limit_reset_at": futureResetAt,
},
},
}
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, true)
require.Nil(t, result)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr, "signature retry should propagate failover instead of falling back to the original 400")
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling)
require.Len(t, upstream.requestBodies, 1, "retry should stop at preflight failover and not issue a second upstream request")
raw, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := raw.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 2)
require.Equal(t, "signature_error", events[0].Kind)
require.Equal(t, "failover", events[1].Kind)
}
// TestStreamUpstreamResponse_UsageAndFirstToken
// 验证usage 字段可被累积/覆盖更新,并且能记录首 token 时间
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":3,"cache_creation_input_tokens":4}}`)
fmt.Fprintln(pw, `data: {"usage":{"output_tokens":5}}`)
}()
start := time.Now().Add(-10 * time.Millisecond)
result := svc.streamUpstreamResponse(c, resp, start)
_ = pr.Close()
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 1, result.usage.InputTokens)
// 第二次事件覆盖 output_tokens
require.Equal(t, 5, result.usage.OutputTokens)
require.Equal(t, 3, result.usage.CacheReadInputTokens)
require.Equal(t, 4, result.usage.CacheCreationInputTokens)
require.NotNil(t, result.firstTokenMs)
// 确保有透传输出
require.Contains(t, rec.Body.String(), "data:")
}
// --- 流式 happy path 测试 ---
// TestStreamUpstreamResponse_NormalComplete
// 验证正常流式转发完成时数据正确透传、usage 正确收集、clientDisconnect=false
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: content_block_delta`)
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "event: message_start")
require.Contains(t, body, "content_block_delta")
require.Contains(t, body, "message_delta")
}
// TestHandleGeminiStreamingResponse_NormalComplete
// 验证:正常 Gemini 流式转发数据正确透传、usage 正确收集
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// 第一个 chunk部分内容
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
fmt.Fprintln(pw, "")
// 第二个 chunk最终内容+完整 usage
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
require.Equal(t, 8, result.usage.InputTokens)
require.Equal(t, 8, result.usage.OutputTokens)
require.Equal(t, 2, result.usage.CacheReadInputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "Hello")
require.Contains(t, body, "world")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// TestHandleClaudeStreamingResponse_NormalComplete
// 验证:正常 Claude 流式转发Gemini→Claude 转换),数据正确转换并输出
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式Gemini 数据嵌套在 "response" 字段下
// ProcessLine 先尝试反序列化为 V1InternalResponse裸格式会导致 Response.UsageMetadata 为空
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini→Claude 转换的 usagepromptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
require.Equal(t, 5, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证输出是 Claude SSE 格式processor 会转换)
body := rec.Body.String()
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// TestHandleGeminiStreamingResponse_ThoughtsTokenCount
// 验证Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens
func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
// promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90
require.Equal(t, 90, result.usage.InputTokens)
// candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110
require.Equal(t, 110, result.usage.OutputTokens)
require.Equal(t, 10, result.usage.CacheReadInputTokens)
}
// TestHandleClaudeStreamingResponse_ThoughtsTokenCount
// 验证Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens
func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
// promptTokenCount=50 → InputTokens=50
require.Equal(t, 50, result.usage.InputTokens)
// candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35
require.Equal(t, 35, result.usage.OutputTokens)
}
// --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
// 验证客户端写入失败后streamUpstreamResponse 继续读取上游以收集 usage
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotNil(t, result.usage)
require.Equal(t, 20, result.usage.OutputTokens)
}
// TestStreamUpstreamResponse_ContextCanceled
// 验证context 取消时返回 usage 且标记 clientDisconnect
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestStreamUpstreamResponse_Timeout
// 验证:上游超时时返回已收集的 usage
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect)
}
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
fmt.Fprintln(pw, "")
// 不关闭 pw → 等待超时
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleGeminiStreamingResponse_ClientDisconnect
// 验证Gemini 流式转发中客户端断开后继续 drain 上游
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "write_failed")
}
// TestHandleGeminiStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestHandleClaudeStreamingResponse_ClientDisconnect
// 验证Claude 流式转发中客户端断开后继续 drain 上游
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
2026-03-06 20:11:57 +08:00
// TestHandleClaudeStreamingResponse_EmptyStream
// 验证:上游只返回无法解析的 SSE 行时,触发 UpstreamFailoverError 而不是向客户端发出残缺流
func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// 所有行均为无法 JSON 解析的内容ProcessLine 全部返回 nil
fmt.Fprintln(pw, "data: not-valid-json")
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, "data: also-invalid")
fmt.Fprintln(pw, "")
}()
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
2026-03-06 20:11:57 +08:00
_ = pr.Close()
// 应当返回 UpstreamFailoverError 而非 nil以便上层触发 failover
require.Error(t, err)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, err, &failoverErr)
require.True(t, failoverErr.RetryableOnSameAccount)
// 客户端不应收到任何 SSE 事件(既无 message_start 也无 message_stop
body := rec.Body.String()
require.NotContains(t, body, "event: message_start")
require.NotContains(t, body, "event: message_stop")
require.NotContains(t, body, "event: message_delta")
}
// TestHandleClaudeStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
func TestExtractSSEUsage(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
line string
expected ClaudeUsage
}{
{
name: "message_delta with output_tokens",
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
expected: ClaudeUsage{OutputTokens: 42},
},
{
name: "non-data line ignored",
line: `event: message_start`,
expected: ClaudeUsage{},
},
{
name: "top-level usage with all fields",
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
usage := &ClaudeUsage{}
svc.extractSSEUsage(tt.line, usage)
require.Equal(t, tt.expected, *usage)
})
}
}
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
func TestAntigravityClientWriter(t *testing.T) {
t.Run("normal write succeeds", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
ok := cw.Write([]byte("hello"))
require.True(t, ok)
require.False(t, cw.Disconnected())
require.Contains(t, rec.Body.String(), "hello")
})
t.Run("write failure marks disconnected", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
ok := cw.Write([]byte("hello"))
require.False(t, ok)
require.True(t, cw.Disconnected())
})
t.Run("subsequent writes are no-op", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
cw.Write([]byte("first"))
ok := cw.Fprintf("second %d", 2)
require.False(t, ok)
require.True(t, cw.Disconnected())
})
}
// TestUnwrapV1InternalResponse 测试 unwrapV1InternalResponse 的各种输入场景
func TestUnwrapV1InternalResponse(t *testing.T) {
svc := &AntigravityGatewayService{}
// 构造 >50KB 的大型 JSON
largePadding := strings.Repeat("x", 50*1024)
largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding))
largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding)
tests := []struct {
name string
input []byte
expected string
wantErr bool
}{
{
name: "正常 response 包装",
input: []byte(`{"response":{"id":"123","content":"hello"}}`),
expected: `{"id":"123","content":"hello"}`,
},
{
name: "无 response 透传",
input: []byte(`{"id":"456"}`),
expected: `{"id":"456"}`,
},
{
name: "空 JSON",
input: []byte(`{}`),
expected: `{}`,
},
{
name: "response 为 null",
input: []byte(`{"response":null}`),
expected: `null`,
},
{
name: "response 为基础类型 string",
input: []byte(`{"response":"hello"}`),
expected: `"hello"`,
},
{
name: "非法 JSON",
input: []byte(`not json`),
expected: `not json`,
},
{
name: "嵌套 response 只解一层",
input: []byte(`{"response":{"response":{"inner":true}}}`),
expected: `{"response":{"inner":true}}`,
},
{
name: "大型 JSON >50KB",
input: largeInput,
expected: largeExpected,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := svc.unwrapV1InternalResponse(tt.input)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.expected, strings.TrimSpace(string(got)))
})
}
}
// --- unwrapV1InternalResponse benchmark 对照组 ---
// unwrapV1InternalResponseOld 旧实现Unmarshal+Marshal 双重开销(仅用于 benchmark 对照)
func unwrapV1InternalResponseOld(body []byte) ([]byte, error) {
var outer map[string]any
if err := json.Unmarshal(body, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"]; ok {
return json.Marshal(resp)
}
return body, nil
}
func BenchmarkUnwrapV1Internal_Old_Small(b *testing.B) {
body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = unwrapV1InternalResponseOld(body)
}
}
func BenchmarkUnwrapV1Internal_New_Small(b *testing.B) {
body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`)
svc := &AntigravityGatewayService{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = svc.unwrapV1InternalResponse(body)
}
}
func BenchmarkUnwrapV1Internal_Old_Large(b *testing.B) {
body := generateLargeUnwrapJSON(10 * 1024) // ~10KB
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = unwrapV1InternalResponseOld(body)
}
}
func BenchmarkUnwrapV1Internal_New_Large(b *testing.B) {
body := generateLargeUnwrapJSON(10 * 1024) // ~10KB
svc := &AntigravityGatewayService{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = svc.unwrapV1InternalResponse(body)
}
}
// generateLargeUnwrapJSON 生成指定最小大小的包含 response 包装的 JSON
func generateLargeUnwrapJSON(minSize int) []byte {
parts := make([]map[string]string, 0)
current := 0
for current < minSize {
text := fmt.Sprintf("这是第 %d 段内容,用于填充 JSON 到目标大小。", len(parts)+1)
parts = append(parts, map[string]string{"text": text})
current += len(text) + 20 // 估算 JSON 编码开销
}
inner := map[string]any{
"candidates": []map[string]any{
{"content": map[string]any{"parts": parts}},
},
"usageMetadata": map[string]any{
"promptTokenCount": 100,
"candidatesTokenCount": 50,
},
}
outer := map[string]any{"response": inner}
b, _ := json.Marshal(outer)
return b
}