mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-20 14:44:45 +08:00
Merge release/custom-0.1.83 into release/custom-0.1.84
This commit is contained in:
@@ -50,6 +50,7 @@ type AdminService interface {
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
||||
|
||||
// Proxy management
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
|
||||
@@ -1706,6 +1707,11 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
|
||||
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
|
||||
if s.proxyLatencyCache == nil || len(proxies) == 0 {
|
||||
return
|
||||
|
||||
@@ -85,7 +85,6 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
||||
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
|
||||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||
)
|
||||
@@ -1311,6 +1310,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||||
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||||
billingModel := mappedModel
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -1372,6 +1372,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
|
||||
}
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
resp := result.resp
|
||||
@@ -1620,7 +1624,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Model: billingModel, // 使用映射模型用于计费和日志
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -1974,6 +1978,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if mappedModel == "" {
|
||||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
||||
}
|
||||
billingModel := mappedModel
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -2044,6 +2049,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response")
|
||||
}
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
resp := result.resp
|
||||
@@ -2199,7 +2208,7 @@ handleSuccess:
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Model: billingModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -2644,7 +2653,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
defaultDur := s.getDefaultRateLimitDuration()
|
||||
|
||||
// 尝试解析模型 key 并设置模型级限流
|
||||
modelKey := resolveAntigravityModelKey(requestedModel)
|
||||
//
|
||||
// 注意:requestedModel 可能是“映射前”的请求模型名(例如 claude-opus-4-6),
|
||||
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
|
||||
// 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。
|
||||
modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||
if strings.TrimSpace(modelKey) == "" {
|
||||
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
|
||||
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
|
||||
modelKey = resolveAntigravityModelKey(requestedModel)
|
||||
}
|
||||
if modelKey != "" {
|
||||
ra := s.resolveResetTime(resetAt, defaultDur)
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
||||
@@ -3875,7 +3893,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
originalModel := claudeReq.Model
|
||||
billingModel := originalModel
|
||||
|
||||
// 构建上游请求 URL
|
||||
upstreamURL := baseURL + "/v1/messages"
|
||||
@@ -3928,7 +3945,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Model: originalModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -3969,7 +3986,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
||||
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Model: originalModel,
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: duration,
|
||||
FirstTokenMs: firstTokenMs,
|
||||
|
||||
@@ -133,6 +133,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
type antigravitySettingRepoStub struct{}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
@@ -159,8 +189,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -417,6 +448,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
|
||||
// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型
|
||||
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hello"},
|
||||
},
|
||||
"max_tokens": 16,
|
||||
"stream": true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-bill-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
const mappedModel = "gemini-3-pro-high"
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Name: "acc-forward-billing",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-5": mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
|
||||
// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型
|
||||
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-bill-2"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
const mappedModel = "gemini-3-pro-high"
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Name: "acc-gemini-billing",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-2.5-flash": mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
// --- 流式 happy path 测试 ---
|
||||
|
||||
// TestStreamUpstreamResponse_NormalComplete
|
||||
|
||||
@@ -191,6 +191,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity}
|
||||
|
||||
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底
|
||||
// 场景:requestedModel 会被默认映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking)
|
||||
body := buildGeminiRateLimitBody("5s")
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false)
|
||||
|
||||
require.Nil(t, result)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
||||
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
||||
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
||||
|
||||
@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||
boolVal, ok := v.(bool)
|
||||
assert.True(t, ok, "value should be bool")
|
||||
assert.True(t, ok, "value should be a bool")
|
||||
assert.True(t, boolVal)
|
||||
}
|
||||
|
||||
|
||||
@@ -890,6 +890,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Priority: 1,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Priority: 2,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}},
|
||||
},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号")
|
||||
|
||||
acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "supporting model")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(50)
|
||||
@@ -1065,6 +1114,36 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-无映射配置-支持所有模型",
|
||||
account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
|
||||
},
|
||||
},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
|
||||
},
|
||||
},
|
||||
model: "gemini-2.5-pro",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -2549,10 +2549,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
}
|
||||
// Gemini API Key 账户直接透传,由上游判断模型是否支持
|
||||
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
|
||||
return true
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user