mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-26 01:14:47 +08:00
add test for fix #935
This commit is contained in:
@@ -134,6 +134,43 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
type queuedHTTPUpstreamStub struct {
|
||||
responses []*http.Response
|
||||
errors []error
|
||||
requestBodies [][]byte
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
if req != nil && req.Body != nil {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
s.requestBodies = append(s.requestBodies, body)
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
} else {
|
||||
s.requestBodies = append(s.requestBodies, nil)
|
||||
}
|
||||
|
||||
idx := s.callCount
|
||||
s.callCount++
|
||||
|
||||
var resp *http.Response
|
||||
if idx < len(s.responses) {
|
||||
resp = s.responses[idx]
|
||||
}
|
||||
var err error
|
||||
if idx < len(s.errors) {
|
||||
err = s.errors[idx]
|
||||
}
|
||||
if resp == nil && err == nil {
|
||||
return nil, errors.New("unexpected upstream call")
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
|
||||
return s.Do(req, proxyURL, accountID, concurrency)
|
||||
}
|
||||
|
||||
type antigravitySettingRepoStub struct{}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
@@ -556,6 +593,92 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
|
||||
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
|
||||
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req-sig-1"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req-sig-2"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
const originalModel = "gemini-3.1-pro-preview"
|
||||
const mappedModel = "gemini-3.1-pro-high"
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "acc-gemini-signature",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
originalModel: mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
|
||||
|
||||
firstReq := string(upstream.requestBodies[0])
|
||||
secondReq := string(upstream.requestBodies[1])
|
||||
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
|
||||
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
|
||||
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
|
||||
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
|
||||
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
|
||||
|
||||
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, events)
|
||||
require.Equal(t, "signature_error", events[0].Kind)
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_UsageAndFirstToken
|
||||
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
||||
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user