Merge pull request #523 from touwaeriol/feat/antigravity-improvements

feat: Antigravity improvements and scope-to-model rate limiting refactor
This commit is contained in:
Wesley Liddick
2026-02-09 09:38:55 +08:00
committed by GitHub
52 changed files with 4565 additions and 2053 deletions

View File

@@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)

View File

@@ -103,6 +103,7 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect

View File

@@ -213,6 +213,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=

View File

@@ -2,11 +2,6 @@ package dto
import "time"
type ScopeRateLimitInfo struct {
ResetAt time.Time `json:"reset_at"`
RemainingSec int64 `json:"remaining_sec"`
}
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
@@ -129,9 +124,6 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
// Antigravity scope 级限流状态(从 extra 提取)
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
@@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
@@ -203,6 +204,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台
@@ -335,7 +341,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -344,6 +350,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
@@ -530,7 +541,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -539,6 +550,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
@@ -801,6 +817,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
delay := time.Duration(switchCount-1) * time.Second
if delay <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
@@ -934,7 +971,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
@@ -962,6 +999,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号

View File

@@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
@@ -30,13 +31,6 @@ import (
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini)
if parsedReq != nil {
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
}
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
@@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var geminiDigestChain string
var geminiPrefixHash string
var geminiSessionUUID string
var matchedDigestChain string
useDigestFallback := sessionBoundAccountID == 0
if useDigestFallback {
@@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
)
// 查找会话
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession(
c.Request.Context(),
derefGroupID(apiKey.GroupID),
geminiPrefixHash,
geminiDigestChain,
)
if found {
matchedDigestChain = foundMatchedChain
sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
@@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
@@ -344,10 +346,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
@@ -422,7 +424,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -433,6 +435,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverErr = failoverErr
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// ForwardNative already wrote the response
@@ -453,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
geminiDigestChain,
geminiSessionUUID,
account.ID,
matchedDigestChain,
); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err)
}

View File

@@ -798,53 +798,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return nil
}
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
scopeKey := string(scope)
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
`UPDATE accounts SET
extra = jsonb_set(
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
ARRAY['antigravity_quota_scopes', $1]::text[],
$2::jsonb,
true
),
updated_at = NOW(),
last_used_at = NOW()
WHERE id = $3 AND deleted_at IS NULL`,
scopeKey,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
if scope == "" {
return nil

View File

@@ -11,63 +11,6 @@ import (
const stickySessionPrefix = "sticky_session:"
// Gemini Trie Lua 脚本
const (
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
// ARGV[2] = TTL seconds (用于刷新)
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
// 查找成功时自动刷新 TTL防止活跃会话意外过期
geminiTrieFindScript = `
local chain = ARGV[1]
local ttl = tonumber(ARGV[2])
local lastMatch = nil
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
local val = redis.call('HGET', KEYS[1], path)
if val and val ~= "" then
lastMatch = val
end
end
if lastMatch then
redis.call('EXPIRE', KEYS[1], ttl)
end
return lastMatch
`
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain
// ARGV[2] = value (uuid:accountID)
// ARGV[3] = TTL seconds
geminiTrieSaveScript = `
local chain = ARGV[1]
local value = ARGV[2]
local ttl = tonumber(ARGV[3])
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
end
redis.call('HSET', KEYS[1], path, value)
redis.call('EXPIRE', KEYS[1], ttl)
return "OK"
`
)
// 模型负载统计相关常量
const (
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
modelLoadTTL = 24 * time.Hour // 调用次数 TTL24 小时无调用后清零)
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
)
type gatewayCache struct {
rdb *redis.Client
}
@@ -108,171 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}
// ============ Antigravity 模型负载统计方法 ============
// modelLoadKey 构建模型调用次数 key
// 格式: ag:model_load:{accountID}:{model}
func modelLoadKey(accountID int64, model string) string {
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
}
// modelLastUsedKey 构建模型最后调度时间 key
// 格式: ag:model_last_used:{accountID}:{model}
func modelLastUsedKey(accountID int64, model string) string {
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
}
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
// 返回更新后的调用次数
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
loadKey := modelLoadKey(accountID, model)
lastUsedKey := modelLastUsedKey(accountID, model)
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, loadKey)
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, err
}
return incrCmd.Val(), nil
}
// GetModelLoadBatch 批量获取账号的模型负载信息
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
if len(accountIDs) == 0 {
return make(map[int64]*service.ModelLoadInfo), nil
}
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
}
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
func (c *gatewayCache) pipelineModelLoadGet(
ctx context.Context,
accountIDs []int64,
model string,
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
pipe := c.rdb.Pipeline()
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
for _, id := range accountIDs {
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
}
_, _ = pipe.Exec(ctx) // 忽略错误key 不存在是正常的
return loadCmds, lastUsedCmds
}
// parseModelLoadResults 解析 Pipeline 结果
func (c *gatewayCache) parseModelLoadResults(
accountIDs []int64,
loadCmds map[int64]*redis.StringCmd,
lastUsedCmds map[int64]*redis.StringCmd,
) map[int64]*service.ModelLoadInfo {
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
for _, id := range accountIDs {
result[id] = &service.ModelLoadInfo{
CallCount: getInt64OrZero(loadCmds[id]),
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
}
}
return result
}
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
func getInt64OrZero(cmd *redis.StringCmd) int64 {
val, _ := cmd.Int64()
return val
}
// getTimeOrZero 从 StringCmd 获取 time.Time失败返回零值
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
val, err := cmd.Int64()
if err != nil {
return time.Time{}
}
return time.Unix(val, 0)
}
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" {
return "", 0, false
}
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
// 使用 Lua 脚本在 Redis 端执行 Trie 查找O(L) 次 HGET1 次网络往返
// 查找成功时自动刷新 TTL防止活跃会话意外过期
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
if err != nil || result == nil {
return "", 0, false
}
value, ok := result.(string)
if !ok || value == "" {
return "", 0, false
}
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
return uuid, accountID, ok
}
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" {
return nil
}
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
value := service.FormatGeminiSessionValue(uuid, accountID)
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}
// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============
// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本)
func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" {
return "", 0, false
}
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
if err != nil || result == nil {
return "", 0, false
}
value, ok := result.(string)
if !ok || value == "" {
return "", 0, false
}
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
return uuid, accountID, ok
}
// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本)
func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" {
return nil
}
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
value := service.FormatGeminiSessionValue(uuid, accountID)
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}

View File

@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
// ============ Gemini Trie 会话测试 ============
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
groupID := int64(1)
prefixHash := "testprefix"
digestChain := "u:hash1-m:hash2-u:hash3"
uuid := "test-uuid-123"
accountID := int64(42)
// 保存会话
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
require.NoError(s.T(), err, "SaveGeminiSession")
// 精确匹配查找
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
require.True(s.T(), found, "should find exact match")
require.Equal(s.T(), uuid, foundUUID)
require.Equal(s.T(), accountID, foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
groupID := int64(1)
prefixHash := "prefixmatch"
shortChain := "u:a-m:b"
longChain := "u:a-m:b-u:c-m:d"
uuid := "uuid-prefix"
accountID := int64(100)
// 保存短链
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
require.NoError(s.T(), err)
// 用长链查找,应该匹配到短链(前缀匹配)
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
require.True(s.T(), found, "should find prefix match")
require.Equal(s.T(), uuid, foundUUID)
require.Equal(s.T(), accountID, foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
groupID := int64(1)
prefixHash := "longestmatch"
// 保存多个不同长度的链
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
require.NoError(s.T(), err)
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
require.NoError(s.T(), err)
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
require.NoError(s.T(), err)
// 查找更长的链,应该匹配到最长的前缀
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
require.True(s.T(), found, "should find longest prefix match")
require.Equal(s.T(), "uuid-long", foundUUID)
require.Equal(s.T(), int64(3), foundAccountID)
// 查找中等长度的链
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
require.True(s.T(), found)
require.Equal(s.T(), "uuid-medium", foundUUID)
require.Equal(s.T(), int64(2), foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
groupID := int64(1)
prefixHash := "nomatch"
digestChain := "u:a-m:b"
// 保存一个会话
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
require.NoError(s.T(), err)
// 用不同的链查找,应该找不到
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
require.False(s.T(), found, "should not find non-matching chain")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
groupID := int64(1)
digestChain := "u:a-m:b"
// 保存到 prefixHash1
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
require.NoError(s.T(), err)
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
require.False(s.T(), found, "different prefixHash should be isolated")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
prefixHash := "sameprefix"
digestChain := "u:a-m:b"
// 保存到 groupID 1
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
require.NoError(s.T(), err)
// 用 groupID 2 查找,应该找不到(分组隔离)
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
require.False(s.T(), found, "different groupID should be isolated")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
groupID := int64(1)
prefixHash := "emptytest"
// 空链不应该保存
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
require.NoError(s.T(), err, "empty chain should not error")
// 空链查找应该返回 false
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
require.False(s.T(), found, "empty chain should not match")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
groupID := int64(1)
prefixHash := "multisession"
// 保存多个不同会话(模拟 1000 个并发会话的场景)
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
require.NoError(s.T(), err)
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
require.True(s.T(), found, "should find session: %s", sess.chain)
require.Equal(s.T(), sess.uuid, foundUUID)
require.Equal(s.T(), sess.accountID, foundAccountID)
}
// 验证继续对话的场景
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
require.True(s.T(), found)
require.Equal(s.T(), "uuid-2", foundUUID)
require.Equal(s.T(), int64(2), foundAccountID)
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))

View File

@@ -1,234 +0,0 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// ============ Gateway Cache 模型负载统计集成测试 ============
type GatewayCacheModelLoadSuite struct {
suite.Suite
}
func TestGatewayCacheModelLoadSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheModelLoadSuite))
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(123)
model := "claude-sonnet-4-20250514"
// 首次调用应返回 1
count1, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
// 第二次调用应返回 2
count2, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(2), count2)
// 第三次调用应返回 3
count3, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(3), count3)
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(456)
model1 := "claude-sonnet-4-20250514"
model2 := "claude-opus-4-5-20251101"
// 不同模型应该独立计数
count1, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
count2, err := cache.IncrModelCallCount(ctx, accountID, model2)
require.NoError(t, err)
require.Equal(t, int64(1), count2)
count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
require.Equal(t, int64(2), count1Again)
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
account1 := int64(111)
account2 := int64(222)
model := "gemini-2.5-pro"
// 不同账号应该独立计数
count1, err := cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
count2, err := cache.IncrModelCallCount(ctx, account2, model)
require.NoError(t, err)
require.Equal(t, int64(1), count2)
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model")
require.NoError(t, err)
require.NotNil(t, result)
require.Empty(t, result)
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
// 查询不存在的账号应返回零值
result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514")
require.NoError(t, err)
require.Len(t, result, 2)
require.Equal(t, int64(0), result[9999].CallCount)
require.True(t, result[9999].LastUsedAt.IsZero())
require.Equal(t, int64(0), result[9998].CallCount)
require.True(t, result[9998].LastUsedAt.IsZero())
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(789)
model := "claude-sonnet-4-20250514"
// 先增加调用次数
beforeIncr := time.Now()
_, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
afterIncr := time.Now()
// 获取负载信息
result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model)
require.NoError(t, err)
require.Len(t, result, 1)
loadInfo := result[accountID]
require.NotNil(t, loadInfo)
require.Equal(t, int64(3), loadInfo.CallCount)
require.False(t, loadInfo.LastUsedAt.IsZero())
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr))
require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr))
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
model := "claude-opus-4-5-20251101"
account1 := int64(1001)
account2 := int64(1002)
account3 := int64(1003) // 不调用
// account1 调用 2 次
_, err := cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
// account2 调用 5 次
for i := 0; i < 5; i++ {
_, err = cache.IncrModelCallCount(ctx, account2, model)
require.NoError(t, err)
}
// 批量获取
result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model)
require.NoError(t, err)
require.Len(t, result, 3)
require.Equal(t, int64(2), result[account1].CallCount)
require.False(t, result[account1].LastUsedAt.IsZero())
require.Equal(t, int64(5), result[account2].CallCount)
require.False(t, result[account2].LastUsedAt.IsZero())
require.Equal(t, int64(0), result[account3].CallCount)
require.True(t, result[account3].LastUsedAt.IsZero())
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(2001)
model1 := "claude-sonnet-4-20250514"
model2 := "gemini-2.5-pro"
// 对 model1 调用 3 次
for i := 0; i < 3; i++ {
_, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
}
// 获取 model1 的负载
result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1)
require.NoError(t, err)
require.Equal(t, int64(3), result1[accountID].CallCount)
// 获取 model2 的负载(应该为 0
result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2)
require.NoError(t, err)
require.Equal(t, int64(0), result2[accountID].CallCount)
}
// ============ 辅助函数测试 ============
func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() {
t := s.T()
key := modelLoadKey(123, "claude-sonnet-4")
require.Equal(t, "ag:model_load:123:claude-sonnet-4", key)
}
func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() {
t := s.T()
key := modelLastUsedKey(456, "gemini-2.5-pro")
require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key)
}

View File

@@ -1008,10 +1008,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return errors.New("not implemented")
}

View File

@@ -50,7 +50,6 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error

View File

@@ -143,10 +143,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call")
}
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
panic("unexpected SetModelRateLimit call")
}

View File

@@ -2,7 +2,6 @@ package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
@@ -12,9 +11,6 @@ const (
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL5 分钟)
anthropicSessionTTLSeconds = 300
// anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀
anthropicTrieKeyPrefix = "anthropic:trie:"
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)
@@ -68,12 +64,6 @@ func rolePrefix(role string) string {
}
}
// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key
// 格式: anthropic:trie:{groupID}:{prefixHash}
func BuildAnthropicTrieKey(groupID int64, prefixHash string) string {
return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {

View File

@@ -236,43 +236,6 @@ func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
}
}
func TestBuildAnthropicTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "anthropic:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "anthropic:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "anthropic:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
tests := []struct {
name string

File diff suppressed because it is too large Load Diff

View File

@@ -4,17 +4,42 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
type antigravityFailingWriter struct {
gin.ResponseWriter
failAfter int // 允许成功写入的次数,之后所有写入返回错误
writes int
}
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
return &AntigravityGatewayService{
settingService: &SettingService{cfg: cfg},
}
}
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
// 验证:ForwardGemini 粘性会话切换时UpstreamFailoverError.ForceCacheBilling 应为 true
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
@@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// --- 流式 happy path 测试 ---
// TestStreamUpstreamResponse_NormalComplete
// 验证正常流式转发完成时数据正确透传、usage 正确收集、clientDisconnect=false
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: content_block_delta`)
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "event: message_start")
require.Contains(t, body, "content_block_delta")
require.Contains(t, body, "message_delta")
}
// TestHandleGeminiStreamingResponse_NormalComplete
// 验证:正常 Gemini 流式转发数据正确透传、usage 正确收集
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// 第一个 chunk部分内容
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
fmt.Fprintln(pw, "")
// 第二个 chunk最终内容+完整 usage
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
require.Equal(t, 8, result.usage.InputTokens)
require.Equal(t, 8, result.usage.OutputTokens)
require.Equal(t, 2, result.usage.CacheReadInputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "Hello")
require.Contains(t, body, "world")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// TestHandleClaudeStreamingResponse_NormalComplete
// 验证:正常 Claude 流式转发Gemini→Claude 转换),数据正确转换并输出
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式Gemini 数据嵌套在 "response" 字段下
// ProcessLine 先尝试反序列化为 V1InternalResponse裸格式会导致 Response.UsageMetadata 为空
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini→Claude 转换的 usagepromptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
require.Equal(t, 5, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证输出是 Claude SSE 格式processor 会转换)
body := rec.Body.String()
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
// 验证客户端写入失败后streamUpstreamResponse 继续读取上游以收集 usage
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotNil(t, result.usage)
require.Equal(t, 20, result.usage.OutputTokens)
}
// TestStreamUpstreamResponse_ContextCanceled
// 验证context 取消时返回 usage 且标记 clientDisconnect
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestStreamUpstreamResponse_Timeout
// 验证:上游超时时返回已收集的 usage
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect)
}
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
fmt.Fprintln(pw, "")
// 不关闭 pw → 等待超时
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleGeminiStreamingResponse_ClientDisconnect
// 验证Gemini 流式转发中客户端断开后继续 drain 上游
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "write_failed")
}
// TestHandleGeminiStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestHandleClaudeStreamingResponse_ClientDisconnect
// 验证Claude 流式转发中客户端断开后继续 drain 上游
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleClaudeStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
func TestExtractSSEUsage(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
line string
expected ClaudeUsage
}{
{
name: "message_delta with output_tokens",
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
expected: ClaudeUsage{OutputTokens: 42},
},
{
name: "non-data line ignored",
line: `event: message_start`,
expected: ClaudeUsage{},
},
{
name: "top-level usage with all fields",
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
usage := &ClaudeUsage{}
svc.extractSSEUsage(tt.line, usage)
require.Equal(t, tt.expected, *usage)
})
}
}
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
func TestAntigravityClientWriter(t *testing.T) {
t.Run("normal write succeeds", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
ok := cw.Write([]byte("hello"))
require.True(t, ok)
require.False(t, cw.Disconnected())
require.Contains(t, rec.Body.String(), "hello")
})
t.Run("write failure marks disconnected", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
ok := cw.Write([]byte("hello"))
require.False(t, ok)
require.True(t, cw.Disconnected())
})
t.Run("subsequent writes are no-op", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
cw.Write([]byte("first"))
ok := cw.Fprintf("second %d", 2)
require.False(t, ok)
require.True(t, cw.Disconnected())
})
}

View File

@@ -2,63 +2,23 @@ package service
import (
"context"
"slices"
"strings"
"time"
)
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type AntigravityQuotaScope string
const (
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
if len(supportedScopes) == 0 {
// 未配置时默认全部支持
return true
}
supported := slices.Contains(supportedScopes, string(scope))
return supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
return resolveAntigravityQuotaScope(requestedModel)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
if model == "" {
return "", false
}
switch {
case strings.HasPrefix(model, "claude-"):
return AntigravityQuotaScopeClaude, true
case strings.HasPrefix(model, "gemini-"):
if isImageGenerationModel(model) {
return AntigravityQuotaScopeGeminiImage, true
}
return AntigravityQuotaScopeGeminiText, true
default:
return "", false
}
}
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
// 返回空字符串表示无法解析
func resolveAntigravityModelKey(requestedModel string) string {
return normalizeAntigravityModelName(requestedModel)
}
// IsSchedulableForModel 结合模型级限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
@@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
return false
}
if a.Platform != PlatformAntigravity {
return true
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return true
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return true
}
now := time.Now()
return !now.Before(*resetAt)
return true
}
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
if !ok {
return nil
}
rawScope, ok := rawScopes[string(scope)].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}
var antigravityAllScopes = []AntigravityQuotaScope{
AntigravityQuotaScopeClaude,
AntigravityQuotaScopeGeminiText,
AntigravityQuotaScopeGeminiImage,
}
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
if a == nil || a.Platform != PlatformAntigravity {
return nil
}
now := time.Now()
result := make(map[string]int64)
for _, scope := range antigravityAllScopes {
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt != nil && now.Before(*resetAt) {
remainingSec := int64(time.Until(*resetAt).Seconds())
if remainingSec > 0 {
result[string(scope)] = remainingSec
}
}
}
if len(result) == 0 {
return nil
}
return result
}
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
// 返回 0 表示未限流或已过期
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
if a == nil || a.Platform != PlatformAntigravity {
return 0
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return 0
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return 0
}
if remaining := time.Until(*resetAt); remaining > 0 {
return remaining
}
return 0
}
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
if modelRemaining > scopeRemaining {
return modelRemaining
}
return scopeRemaining
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
}

View File

@@ -59,12 +59,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string,
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
type scopeLimitCall struct {
accountID int64
scope AntigravityQuotaScope
resetAt time.Time
}
type rateLimitCall struct {
accountID int64
resetAt time.Time
@@ -78,16 +72,10 @@ type modelRateLimitCall struct {
type stubAntigravityAccountRepo struct {
AccountRepository
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
modelRateLimitCalls []modelRateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
return nil
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
return nil
@@ -131,10 +119,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCalled = true
return nil
},
@@ -155,23 +142,6 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
// 分区限流始终开启,不再支持通过环境变量关闭
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
call := repo.scopeCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
}
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
@@ -189,7 +159,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
@@ -200,22 +170,22 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reasonscope 限流
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason走模型级限流兜底
body := buildGeminiRateLimitBody("5s")
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
// 不应该触发模型限流,应该走 scope 限流
// handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED
// 但 429 兜底逻辑会使用 requestedModel 设置模型级限流
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Len(t, repo.scopeCalls, 1)
require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
@@ -235,7 +205,7 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
@@ -263,12 +233,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 503 非模型限流不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
}
@@ -281,12 +250,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
// 503 + 空响应体 → 不做任何处理
body := []byte(`{}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 503 空响应不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Empty(t, repo.scopeCalls)
require.Empty(t, repo.rateCalls)
}
@@ -307,15 +275,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
account.RateLimitResetAt = nil
account.Extra = map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future.Format(time.RFC3339),
},
},
}
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
}
@@ -635,6 +595,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 7 * time.Second,
modelName: "gemini-pro",
},
{
@@ -652,6 +613,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 39 * time.Second,
modelName: "gemini-3-pro-high",
},
{
@@ -669,6 +631,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 30 * time.Second,
modelName: "gemini-2.5-flash",
},
{
@@ -686,6 +649,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 30 * time.Second,
modelName: "claude-sonnet-4-5",
},
}
@@ -704,6 +668,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
}
}
if shouldRateLimit && tt.minWait > 0 {
if wait < tt.minWait {
t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait)
}
}
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
t.Errorf("modelName = %q, want %q", model, tt.modelName)
}
@@ -832,7 +801,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
@@ -875,7 +844,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})

View File

@@ -75,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -127,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -194,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -269,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
httpUpstream: upstream,
accountRepo: repo,
isStickySession: false,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -331,7 +331,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -387,7 +387,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -436,7 +436,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T)
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -487,7 +487,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -548,7 +548,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
@@ -604,7 +604,7 @@ func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) {
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -662,7 +662,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -754,7 +754,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-abc",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -842,7 +842,7 @@ func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSessio
isStickySession: false,
groupID: 42,
sessionHash: "", // 非粘性会话sessionHash 为空
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -918,7 +918,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-nil-cache",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -983,7 +983,7 @@ func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-success",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -1043,7 +1043,7 @@ func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-long-delay",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -1108,7 +1108,7 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
isStickySession: true,
groupID: 99,
sessionHash: "sticky-net-error",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -1188,7 +1188,7 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
isStickySession: true,
groupID: 77,
sessionHash: "sticky-503-short",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -1278,7 +1278,7 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
isStickySession: true,
groupID: 55,
sessionHash: "sticky-loop-test",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
@@ -1296,4 +1296,4 @@ func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagat
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
}
}

View File

@@ -0,0 +1,69 @@
package service
import (
"strconv"
"strings"
"time"
gocache "github.com/patrickmn/go-cache"
)
// digestSessionTTL 摘要会话默认 TTL
const digestSessionTTL = 5 * time.Minute
// sessionEntry flat cache 条目
type sessionEntry struct {
uuid string
accountID int64
}
// DigestSessionStore 内存摘要会话存储flat cache 实现)
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
type DigestSessionStore struct {
cache *gocache.Cache
}
// NewDigestSessionStore 创建内存摘要会话存储
func NewDigestSessionStore() *DigestSessionStore {
return &DigestSessionStore{
cache: gocache.New(digestSessionTTL, time.Minute),
}
}
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain用于删旧 key。
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
if digestChain == "" {
return
}
ns := buildNS(groupID, prefixHash)
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
if oldDigestChain != "" && oldDigestChain != digestChain {
s.cache.Delete(ns + oldDigestChain)
}
}
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" {
return "", 0, "", false
}
ns := buildNS(groupID, prefixHash)
chain := digestChain
for {
if val, ok := s.cache.Get(ns + chain); ok {
if e, ok := val.(*sessionEntry); ok {
return e.uuid, e.accountID, chain, true
}
}
i := strings.LastIndex(chain, "-")
if i < 0 {
return "", 0, "", false
}
chain = chain[:i]
}
}
// buildNS 构建 namespace 前缀
func buildNS(groupID int64, prefixHash string) string {
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
}

View File

@@ -0,0 +1,312 @@
//go:build unit
package service
import (
"fmt"
"sync"
"testing"
"time"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
// 保存短链
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
// 用长链查找,应前缀匹配到短链
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-short", uuid)
assert.Equal(t, int64(10), accountID)
assert.Equal(t, "u:a-m:b", matchedChain)
}
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "uuid-3", uuid)
assert.Equal(t, int64(3), accountID)
// 查找中等长度,应匹配到 "u:a-m:b"
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
store := NewDigestSessionStore()
// 第一轮:保存 "u:a-m:b"
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
// 旧链 "u:a-m:b" 应已被删除
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found, "old chain should be deleted")
// 新链应能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
store := NewDigestSessionStore()
// 相同系统提示词,不同用户提示词
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(200), accountID)
}
func TestDigestSessionStore_NoMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 完全不同的 chain
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
// 不同 prefixHash 应隔离
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 不同 groupID 应隔离
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
store := NewDigestSessionStore()
// 空链不应保存
store.Save(1, "prefix", "", "uuid-1", 100, "")
_, _, _, found := store.Find(1, "prefix", "")
assert.False(t, found)
}
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 立即应该能找到
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
// 等待过期 + 清理周期
time.Sleep(300 * time.Millisecond)
// 过期后应找不到
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
store := NewDigestSessionStore()
var wg sync.WaitGroup
const goroutines = 50
const operations = 100
wg.Add(goroutines)
for g := 0; g < goroutines; g++ {
go func(id int) {
defer wg.Done()
prefix := fmt.Sprintf("prefix-%d", id%5)
for i := 0; i < operations; i++ {
chain := fmt.Sprintf("u:%d-m:%d", id, i)
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
store.Save(1, prefix, chain, uuid, int64(id), "")
store.Find(1, prefix, chain)
}
}(g)
}
wg.Wait()
}
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
store := NewDigestSessionStore()
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
require.True(t, found, "should find session: %s", sess.chain)
assert.Equal(t, sess.uuid, uuid)
assert.Equal(t, sess.accountID, accountID)
}
// 验证继续对话的场景
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
store := NewDigestSessionStore()
// 插入 1000 个会话
for i := 0; i < 1000; i++ {
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
// 查找性能测试
start := time.Now()
const lookups = 10000
for i := 0; i < lookups; i++ {
idx := i % 1000
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
_, _, _, found := store.Find(1, "prefix", chain)
assert.True(t, found)
}
elapsed := time.Since(start)
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
}
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
// 精确匹配
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
// 前缀匹配(截断后命中)
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
}
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
store := NewDigestSessionStore()
// 模拟 100 个独立会话,每个进行 10 轮对话
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
for conv := 0; conv < 100; conv++ {
var prevMatchedChain string
for round := 0; round < 10; round++ {
chain := fmt.Sprintf("s:sys-u:user%d", conv)
for r := 0; r < round; r++ {
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
}
uuid := fmt.Sprintf("uuid-conv%d", conv)
_, _, matched, _ := store.Find(1, "prefix", chain)
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
prevMatchedChain = matched
_ = prevMatchedChain
}
}
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
// 允许少量并发残留,但绝不能接近 100×10=1000
itemCount := store.cache.ItemCount()
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
}
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
// 使用极短 TTL 验证大量写入后 cache 能被清理
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
// 插入 500 个不同的 key无 oldDigestChain模拟最坏场景全是新会话首轮
for i := 0; i < 500; i++ {
chain := fmt.Sprintf("u:user%d", i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
assert.Equal(t, 500, store.cache.ItemCount())
// 等待 TTL + 清理周期
time.Sleep(300 * time.Millisecond)
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
}
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
store := NewDigestSessionStore()
// 保存 chain
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 用户重发相同消息oldDigestChain == digestChain不应删掉刚设置的 key
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
// 仍然能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}

View File

@@ -0,0 +1,366 @@
//go:build unit
package service
import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mocks (scoped to this file by naming convention)
// ---------------------------------------------------------------------------
// epFixedUpstream returns a fixed response for every request.
type epFixedUpstream struct {
statusCode int
body string
calls int
}
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
u.calls++
return &http.Response{
StatusCode: u.statusCode,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(u.body)),
}, nil
}
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
// epAccountRepo records SetTempUnschedulable / SetError calls.
type epAccountRepo struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
}
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrCalls++
return nil
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
func saveAndSetBaseURLs(t *testing.T) {
t.Helper()
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvail := antigravity.DefaultURLAvailability
antigravity.BaseURLs = []string{"https://ep-test.example"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
t.Cleanup(func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvail
})
}
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
return antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[ep-test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: handleError,
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
tests := []struct {
name string
upstreamStatus int
upstreamBody string
customCodes []any
expectHandleError int
expectUpstream int
expectStatusCode int
}{
{
name: "429_in_custom_codes_matched",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(429)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "429_not_in_custom_codes_skipped",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(500)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "500_in_custom_codes_matched",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(500)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 500,
},
{
name: "500_not_in_custom_codes_skipped",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(429)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 500,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": tt.customCodes,
},
}
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
})
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_TempUnschedulable
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
tempRulesAccount := func(rules []any) *Account {
return &Account{
ID: 200,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": rules,
},
}
}
overloadedRule := map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
}
rateLimitRule := map[string]any{
"error_code": float64(429),
"keywords": []any{"rate limited keyword"},
"duration_minutes": float64(5),
}
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{rateLimitRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
// Use a short-lived context: the backoff sleep (~1s) will be
// interrupted, proving the code entered the default retry path
// instead of breaking early via error policy.
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
result, err := svc.antigravityRetryLoop(p)
// Context cancellation during backoff proves default retry was entered
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
})
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NilRateLimitService
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
// rateLimitService is nil — must not panic
svc := &AntigravityGatewayService{rateLimitService: nil}
account := &Account{
ID: 300,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
// Should not panic; enters the default retry path (eventually times out)
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1)
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
// Plain OAuth account with no error policy configured
account := &Account{
ID: 400,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
}

View File

@@ -0,0 +1,289 @@
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "no_policy_oauth_returns_none",
account: &Account{
ID: 1,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
// no custom error codes, no temp rules
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_hit_returns_matched",
account: &Account{
ID: 2,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyMatched,
},
{
name: "custom_error_codes_miss_returns_skipped",
account: &Account{
ID: 3,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 503,
body: []byte(`"error"`),
expected: ErrorPolicySkipped,
},
{
name: "temp_unschedulable_hit_returns_temp_unscheduled",
account: &Account{
ID: 4,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_body_miss_returns_none",
account: &Account{
ID: 5,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`random msg`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_override_temp_unschedulable",
account: &Account{
ID: 6,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
})
}
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------
func TestApplyErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expectedHandled bool
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
handleErrorCalls int
}{
{
name: "none_not_handled",
account: &Account{
ID: 10,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: false,
handleErrorCalls: 0,
},
{
name: "skipped_handled_no_handleError",
account: &Account{
ID: 11,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500, // not in custom codes
body: []byte(`"error"`),
expectedHandled: true,
handleErrorCalls: 0,
},
{
name: "matched_handled_calls_handleError",
account: &Account{
ID: 12,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: true,
handleErrorCalls: 1,
},
{
name: "temp_unscheduled_returns_switch_error",
account: &Account{
ID: 13,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expectedHandled: true,
expectedSwitchErr: true,
handleErrorCalls: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{
rateLimitService: rlSvc,
}
var handleErrorCount int
p := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: tt.account,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
},
isStickySession: true,
}
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
if tt.expectedSwitchErr {
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, retErr, &switchErr)
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
} else {
require.NoError(t, retErr)
}
})
}
}
// ---------------------------------------------------------------------------
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
// ---------------------------------------------------------------------------
type errorPolicyRepoStub struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
lastErrorMsg string
}
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrCalls++
r.lastErrorMsg = errorMsg
return nil
}

View File

@@ -142,9 +142,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
@@ -216,30 +213,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return nil
}
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int

View File

@@ -6,9 +6,19 @@ import (
"fmt"
"math"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
// 仅在 GenerateSessionHash 第 3 级 fallback消息内容 hash时混入
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
type SessionContext struct {
ClientIP string
UserAgent string
APIKeyID int64
}
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
@@ -22,20 +32,22 @@ import (
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking部分平台会影响最终模型名
MaxTokens int // max_tokens 值(用于探测请求拦截)
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking部分平台会影响最终模型名
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选请求上下文区分因子nil 时行为不变)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// protocol 指定请求协议格式domain.PlatformAnthropic / domain.PlatformGemini
// 不同协议使用不同的 system/messages 字段名。
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
@@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.MetadataUserID = userID
}
}
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
switch protocol {
case domain.PlatformGemini:
// Gemini 原生格式: systemInstruction.parts / contents
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
parsed.System = parts
}
}
if contents, ok := req["contents"].([]any); ok {
parsed.Messages = contents
}
default:
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
}
// thinking: {type: "enabled"}

View File

@@ -4,12 +4,13 @@ import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/stretchr/testify/require"
)
func TestParseGatewayRequest(t *testing.T) {
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
require.True(t, parsed.Stream)
@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, 1, parsed.MaxTokens)
}
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, 0, parsed.MaxTokens)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require.True(t, parsed.HasSystem)
@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
body := []byte(`{"model":123}`)
_, err := ParseGatewayRequest(body)
_, err := ParseGatewayRequest(body, "")
require.Error(t, err)
}
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
body := []byte(`{"stream":"true"}`)
_, err := ParseGatewayRequest(body)
_, err := ParseGatewayRequest(body, "")
require.Error(t, err)
}
// ============ Gemini 原生格式解析测试 ============
func TestParseGatewayRequest_GeminiContents(t *testing.T) {
body := []byte(`{
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Len(t, parsed.Messages, 3, "should parse contents as Messages")
require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem")
require.Nil(t, parsed.System, "no systemInstruction means nil System")
}
func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) {
body := []byte(`{
"systemInstruction": {
"parts": [{"text": "You are a helpful assistant."}]
},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System")
parts, ok := parsed.System.([]any)
require.True(t, ok)
require.Len(t, parts, 1)
partMap, ok := parts[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "You are a helpful assistant.", partMap["text"])
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiWithModel(t *testing.T) {
body := []byte(`{
"model": "gemini-2.5-pro",
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Equal(t, "gemini-2.5-pro", parsed.Model)
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) {
// Gemini 格式下 system/messages 字段应被忽略
body := []byte(`{
"system": "should be ignored",
"messages": [{"role": "user", "content": "ignored"}],
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field")
require.Nil(t, parsed.System, "no systemInstruction = nil System")
require.Len(t, parsed.Messages, 1, "should use contents, not messages")
}
func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) {
body := []byte(`{"contents": []}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Empty(t, parsed.Messages)
}
func TestParseGatewayRequest_GeminiNoContents(t *testing.T) {
body := []byte(`{"model": "gemini-2.5-flash"}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Nil(t, parsed.Messages)
require.Equal(t, "gemini-2.5-flash", parsed.Model)
}
func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) {
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
body := []byte(`{
"system": "real system",
"messages": [{"role": "user", "content": "real content"}],
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
"systemInstruction": {"parts": [{"text": "ignored"}]}
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic)
require.NoError(t, err)
require.True(t, parsed.HasSystem)
require.Equal(t, "real system", parsed.System)
require.Len(t, parsed.Messages, 1)
msg, ok := parsed.Messages[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "real content", msg["content"])
}
func TestFilterThinkingBlocks(t *testing.T) {
containsThinkingBlock := func(body []byte) bool {
var req map[string]any

View File

@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -17,6 +16,7 @@ import (
"os"
"regexp"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -245,9 +246,6 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
"accept": true,
@@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话Sticky Session的存储、查询、刷新和删除功能。
//
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type ModelLoadInfo struct {
CallCount int64 // 当前分钟调用次数 / Call count in current minute
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
@@ -295,32 +286,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息uuid, accountID
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
// FindAnthropicSession 查找 Anthropic 会话Trie 匹配)
// Find Anthropic session using Trie matching
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveAnthropicSession 保存 Anthropic 会话
// Save Anthropic session binding
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -415,6 +380,7 @@ type GatewayService struct {
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
@@ -448,6 +414,7 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
digestStore *DigestSessionStore,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
@@ -457,6 +424,7 @@ func NewGatewayService(
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
digestStore: digestStore,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
@@ -490,8 +458,17 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent)
}
// 3. 最后 fallback: 使用 system + 所有消息的完整摘要串
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
var combined strings.Builder
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|")
}
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
@@ -500,9 +477,20 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
}
for _, msg := range parsed.Messages {
if m, ok := msg.(map[string]any); ok {
msgText := s.extractTextFromContent(m["content"])
if msgText != "" {
_, _ = combined.WriteString(msgText)
if content, exists := m["content"]; exists {
// Anthropic: messages[].content
if msgText := s.extractTextFromContent(content); msgText != "" {
_, _ = combined.WriteString(msgText)
}
} else if parts, ok := m["parts"].([]any); ok {
// Gemini: contents[].parts[].text
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
_, _ = combined.WriteString(text)
}
}
}
}
}
}
@@ -536,35 +524,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息uuid, accountID
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveGeminiSession 保存 Gemini 会话
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain用于删旧 key。
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain)
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveAnthropicSession 保存 Anthropic 会话
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
@@ -649,8 +639,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
}
func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:16]) // 32字符
h := xxhash.Sum64String(content)
return strconv.FormatUint(h, 36)
}
// replaceModelInBody 替换请求体中的model字段
@@ -1009,13 +999,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
@@ -1209,6 +1192,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(routingAvailable)
// 4. 尝试获取槽位
for _, item := range routingAvailable {
@@ -1362,10 +1346,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil
}
} else {
// Antigravity 平台:获取模型负载信息
var modelLoadMap map[int64]*ModelLoadInfo
isAntigravity := platform == PlatformAntigravity
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
@@ -1380,109 +1360,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
modelToAccountIDs := make(map[string][]int64)
for _, item := range available {
mappedModel := mapAntigravityModel(item.account, requestedModel)
if mappedModel == "" {
continue
}
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
// 分层过滤选择:优先级 → 负载率 → LRU
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
}
for id, info := range batch {
modelLoadMap[id] = info
}
}
if len(modelLoadMap) == 0 {
modelLoadMap = nil
}
}
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 移除已尝试的账号,重新选择
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
} else {
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
available = newAvailable
}
}
@@ -2018,87 +1933,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
shuffleWithinPriorityAndLastUsed(accounts)
}
// selectByCallCount 从候选账号中选择调用次数最少的账号Antigravity 专用)
// 新账号CallCount=0使用平均调用次数作为虚拟值避免冷启动被猛调
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
func shuffleWithinSortGroups(accounts []accountWithLoad) {
if len(accounts) <= 1 {
return
}
if len(accounts) == 1 {
return &accounts[0]
}
// 如果没有负载信息,回退到 LRU
if modelLoadMap == nil {
return selectByLRU(accounts, preferOAuth)
}
// 1. 计算平均调用次数(用于新账号冷启动)
var totalCallCount int64
var countWithCalls int
for _, acc := range accounts {
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
totalCallCount += info.CallCount
countWithCalls++
i := 0
for i < len(accounts) {
j := i + 1
for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
j++
}
}
var avgCallCount int64
if countWithCalls > 0 {
avgCallCount = totalCallCount / int64(countWithCalls)
}
// 2. 获取每个账号的有效调用次数
getEffectiveCallCount := func(acc accountWithLoad) int64 {
if acc.account == nil {
return 0
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
}
info := modelLoadMap[acc.account.ID]
if info == nil || info.CallCount == 0 {
return avgCallCount // 新账号使用平均值
}
return info.CallCount
i = j
}
}
// 3. 找到最小调用次数
minCount := getEffectiveCallCount(accounts[0])
for _, acc := range accounts[1:] {
if c := getEffectiveCallCount(acc); c < minCount {
minCount = c
}
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
if a.account.Priority != b.account.Priority {
return false
}
// 4. 收集所有具有最小调用次数的账号
var candidateIdxs []int
for i, acc := range accounts {
if getEffectiveCallCount(acc) == minCount {
candidateIdxs = append(candidateIdxs, i)
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return false
}
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
}
// 5. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
if len(accounts) <= 1 {
return
}
// 6. preferOAuth 处理
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
i := 0
for i < len(accounts) {
j := i + 1
for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
j++
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
}
i = j
}
}
// 7. 随机选择
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
// sameAccountGroup 判断两个 Account 是否属于同一排序组Priority + LastUsedAt
func sameAccountGroup(a, b *Account) bool {
if a.Priority != b.Priority {
return false
}
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
}
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
func sameLastUsedAt(a, b *time.Time) bool {
switch {
case a == nil && b == nil:
return true
case a == nil || b == nil:
return false
default:
return a.Unix() == b.Unix()
}
}
// sortCandidatesForFallback 根据配置选择排序策略
@@ -2153,13 +2060,6 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
@@ -5171,27 +5071,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return normalized, nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
if !ok {
return nil // 无法解析 scope跳过检查
}
group, err := s.resolveGroupByID(ctx, groupID)
if err != nil {
return nil // 查询失败时放行
}
if group == nil {
return nil // 分组不存在时放行
}
if !IsScopeSupported(group.SupportedModelScopes, scope) {
return ErrModelScopeNotSupported
}
return nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {

View File

@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
if err != nil {
b.Fatalf("解析请求失败: %v", err)
}

View File

@@ -0,0 +1,384 @@
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
// for the ErrorPolicyNone path (original logic preserved).
// ---------------------------------------------------------------------------
func TestShouldFailoverGeminiUpstreamError(t *testing.T) {
svc := &GeminiMessagesCompatService{}
tests := []struct {
name string
statusCode int
expected bool
}{
{"401_failover", 401, true},
{"403_failover", 403, true},
{"429_failover", 429, true},
{"529_failover", 529, true},
{"500_failover", 500, true},
{"502_failover", 502, true},
{"503_failover", 503, true},
{"400_no_failover", 400, false},
{"404_no_failover", 404, false},
{"422_no_failover", 422, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
// correctly for Gemini platform accounts (API Key type).
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "gemini_apikey_custom_codes_hit",
account: &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 429,
body: []byte(`{"error":"rate limited"}`),
expected: ErrorPolicyMatched,
},
{
name: "gemini_apikey_custom_codes_miss",
account: &Account{
ID: 101,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicySkipped,
},
{
name: "gemini_apikey_no_custom_codes_returns_none",
account: &Account{
ID: 102,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicyNone,
},
{
name: "gemini_apikey_temp_unschedulable_hit",
account: &Account{
ID: 103,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "gemini_custom_codes_override_temp_unschedulable",
account: &Account{
ID: 104,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result)
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
// paths produce the correct behavior for each ErrorPolicyResult.
//
// These tests simulate the inline error policy switch in handleClaudeCompat
// and forwardNativeGemini by calling the same methods in the same order.
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicyIntegration(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
account *Account
statusCode int
respBody []byte
expectFailover bool // expect UpstreamFailoverError
expectHandleError bool // expect handleGeminiUpstreamError to be called
expectShouldFailover bool // for None path, whether shouldFailover triggers
}{
{
name: "custom_codes_matched_429_failover",
account: &Account{
ID: 200,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
},
{
name: "custom_codes_skipped_500_no_failover",
account: &Account{
ID: 201,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
respBody: []byte(`{"error":"internal"}`),
expectFailover: false,
expectHandleError: false,
},
{
name: "temp_unschedulable_matched_failover",
account: &Account{
ID: 202,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
respBody: []byte(`overloaded`),
expectFailover: true,
expectHandleError: true,
},
{
name: "no_policy_429_failover_via_shouldFailover",
account: &Account{
ID: 203,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
expectShouldFailover: true,
},
{
name: "no_policy_400_no_failover",
account: &Account{
ID: 204,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 400,
respBody: []byte(`{"error":"bad request"}`),
expectFailover: false,
expectHandleError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &geminiErrorPolicyRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &GeminiMessagesCompatService{
accountRepo: repo,
rateLimitService: rlSvc,
}
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// Simulate the Claude compat error handling path (same logic as native).
// This mirrors the inline switch in handleClaudeCompat.
var handleErrorCalled bool
var gotFailover bool
ctx := context.Background()
statusCode := tt.statusCode
respBody := tt.respBody
account := tt.account
headers := http.Header{}
if svc.rateLimitService != nil {
switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) {
case ErrorPolicySkipped:
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
gotFailover = false
handleErrorCalled = false
goto verify
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
gotFailover = true
goto verify
}
}
// ErrorPolicyNone → original logic
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
if svc.shouldFailoverGeminiUpstreamError(statusCode) {
gotFailover = true
}
verify:
require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch")
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch")
if tt.expectShouldFailover {
require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode),
"shouldFailoverGeminiUpstreamError should return true for status %d", statusCode)
}
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
svc := &GeminiMessagesCompatService{
rateLimitService: nil,
}
// When rateLimitService is nil, error policy is skipped → falls through to
// shouldFailoverGeminiUpstreamError (original logic).
// Verify this doesn't panic and follows expected behavior.
ctx := context.Background()
account := &Account{
ID: 300,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
}
// The nil check should prevent CheckErrorPolicy from being called
if svc.rateLimitService != nil {
t.Fatal("rateLimitService should be nil for this test")
}
// shouldFailoverGeminiUpstreamError still works
require.True(t, svc.shouldFailoverGeminiUpstreamError(429))
require.False(t, svc.shouldFailoverGeminiUpstreamError(400))
// handleGeminiUpstreamError should not panic with nil rateLimitService
require.NotPanics(t, func() {
svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`))
})
}
// ---------------------------------------------------------------------------
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
// ---------------------------------------------------------------------------
type geminiErrorPolicyRepo struct {
mockAccountRepoForGemini
setErrorCalls int
setRateLimitedCalls int
setTempCalls int
}
func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrorCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
r.setRateLimitedCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.setTempCalls++
return nil
}

View File

@@ -831,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
// 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
case ErrorPolicySkipped:
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamDetail = truncateString(string(respBody), maxBytes)
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
@@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
// Checked before error policy so it always works regardless of custom error codes.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
@@ -1270,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
if tempMatched {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
// 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil {
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
case ErrorPolicySkipped:
respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
upstreamDetail = truncateString(string(evBody), maxBytes)
c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))

View File

@@ -133,9 +133,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
@@ -269,30 +266,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
return nil
}
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()

View File

@@ -6,26 +6,11 @@ import (
"encoding/json"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/cespare/xxhash/v2"
)
// Gemini 会话 ID Fallback 相关常量
const (
// geminiSessionTTLSeconds Gemini 会话缓存 TTL5 分钟)
geminiSessionTTLSeconds = 300
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
geminiSessionKeyPrefix = "gemini:sess:"
)
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
func GeminiSessionTTL() time.Duration {
return geminiSessionTTLSeconds * time.Second
}
// shortHash 使用 XXHash64 + Base36 生成短 hash16 字符)
// XXHash64 比 SHA256 快约 10 倍Base36 比 Hex 短约 20%
func shortHash(data []byte) string {
@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m
return base64.RawURLEncoding.EncodeToString(hash[:12])
}
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
}
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
// 用于 MGET 批量查询最长匹配
func GenerateDigestChainPrefixes(chain string) []string {
if chain == "" {
return nil
}
var prefixes []string
c := chain
for c != "" {
prefixes = append(prefixes, c)
// 找到最后一个 "-" 的位置
if i := strings.LastIndex(c, "-"); i > 0 {
c = c[:i]
} else {
break
}
}
return prefixes
}
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
// 格式: {uuid}:{accountID}
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string {
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
const geminiDigestSessionKeyPrefix = "gemini:digest:"
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
const geminiTrieKeyPrefix = "gemini:trie:"
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
// 格式: gemini:trie:{groupID}:{prefixHash}
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话

View File

@@ -1,41 +1,14 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// mockGeminiSessionCache 模拟 Redis 缓存
type mockGeminiSessionCache struct {
sessions map[string]string // key -> value
}
func newMockGeminiSessionCache() *mockGeminiSessionCache {
return &mockGeminiSessionCache{sessions: make(map[string]string)}
}
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
value := FormatGeminiSessionValue(uuid, accountID)
m.sessions[key] = value
}
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
prefixes := GenerateDigestChainPrefixes(digestChain)
for _, p := range prefixes {
key := BuildGeminiSessionKey(groupID, prefixHash, p)
if val, ok := m.sessions[key]; ok {
return ParseGeminiSessionValue(val)
}
}
return "", 0, false
}
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
func TestGeminiSessionContinuousConversation(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
sessionUUID := "session-uuid-12345"
@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 1 chain: %s", chain1)
// 第一轮:没有找到会话,创建新会话
_, _, found := cache.Find(groupID, prefixHash, chain1)
_, _, _, found := store.Find(groupID, prefixHash, chain1)
if found {
t.Error("Round 1: should not find existing session")
}
// 保存第一轮会话
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
// 保存第一轮会话(首轮无旧 chain
store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
// 模拟第二轮对话(用户继续对话)
req2 := &antigravity.GeminiRequest{
@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 2 chain: %s", chain2)
// 第二轮:应该能找到会话(通过前缀匹配)
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
if !found {
t.Error("Round 2: should find session via prefix matching")
}
@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
}
// 保存第二轮会话
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
// 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
// 模拟第三轮对话
req3 := &antigravity.GeminiRequest{
@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 3 chain: %s", chain3)
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
if !found {
t.Error("Round 3: should find session via prefix matching")
}
@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
if foundAccID != accountID {
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
}
t.Log("✓ Continuous conversation session matching works correctly!")
}
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
func TestGeminiSessionDifferentConversations(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
},
}
chain1 := BuildGeminiDigestChain(req1)
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
// 第二个完全不同的会话
req2 := &antigravity.GeminiRequest{
@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
chain2 := BuildGeminiDigestChain(req2)
// 不同会话不应该匹配
_, _, found := cache.Find(groupID, prefixHash, chain2)
_, _, _, found := store.Find(groupID, prefixHash, chain2)
if found {
t.Error("Different conversations should not match")
}
t.Log("✓ Different conversations are correctly isolated!")
}
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
// 创建一个三轮对话
req := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
},
}
fullChain := BuildGeminiDigestChain(req)
prefixes := GenerateDigestChainPrefixes(fullChain)
t.Logf("Full chain: %s", fullChain)
t.Logf("Prefixes (longest first): %v", prefixes)
// 验证前缀生成顺序(从长到短)
if len(prefixes) != 4 {
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
}
// 保存不同轮次的会话到不同账号
// 第一轮(最短前缀)-> 账号 1
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
// 第二轮 -> 账号 2
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
// 第三轮(最长前缀,完整链)-> 账号 3
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
// 查找应该返回最长匹配(账号 3
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
// 查找更长的链,应该返回最长匹配(账号 3
_, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
if !found {
t.Error("Should find session")
}
if accID != 3 {
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
}
t.Log("✓ Longest prefix matching works correctly!")
}
// 确保 context 包被使用(避免未使用的导入警告)
var _ = context.Background

View File

@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
}
}
func TestGenerateDigestChainPrefixes(t *testing.T) {
tests := []struct {
name string
chain string
want []string
wantLen int
}{
{
name: "empty",
chain: "",
wantLen: 0,
},
{
name: "single part",
chain: "u:abc123",
want: []string{"u:abc123"},
wantLen: 1,
},
{
name: "two parts",
chain: "s:xyz-u:abc",
want: []string{"s:xyz-u:abc", "s:xyz"},
wantLen: 2,
},
{
name: "four parts",
chain: "s:a-u:b-m:c-u:d",
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
wantLen: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateDigestChainPrefixes(tt.chain)
if len(result) != tt.wantLen {
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
}
if tt.want != nil {
for i, want := range tt.want {
if i >= len(result) {
t.Errorf("missing prefix at index %d", i)
continue
}
if result[i] != want {
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
}
}
}
})
}
}
func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct {
name string
@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) {
}
})
}
func TestBuildGeminiTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "gemini:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "gemini:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "gemini:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) {
}
}
func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "non-antigravity platform",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "claude scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "gemini_text scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"gemini_text": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "gemini-3-flash",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "expired scope rate limit",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "unsupported model",
account: &Account{
Platform: PlatformAntigravity,
},
requestedModel: "gpt-4",
minExpected: 0,
maxExpected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
func TestGetRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
@@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
maxExpected: 0,
},
{
name: "model remaining > scope remaining - returns model",
name: "model rate limited - 15 minutes",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
"rate_limit_reset_at": future15m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
maxExpected: 16 * time.Minute,
},
{
name: "scope remaining > model remaining - returns scope",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
minExpected: 14 * time.Minute,
maxExpected: 16 * time.Minute,
},
{
@@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "only scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "neither rate limited",
account: &Account{

View File

@@ -580,10 +580,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
@@ -618,6 +614,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(available)
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)

View File

@@ -204,30 +204,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i
return nil
}
func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)

View File

@@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
}
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
if acc.Platform != "" {
if _, ok := platform[acc.Platform]; !ok {
@@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
p.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if p.ScopeRateLimitCount == nil {
p.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
p.ScopeRateLimitCount[scope]++
}
}
}
for _, grp := range acc.Groups {
@@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
g.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if g.ScopeRateLimitCount == nil {
g.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
g.ScopeRateLimitCount[scope]++
}
}
}
displayGroupID := int64(0)
@@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
item.RateLimitRemainingSec = &remainingSec
}
}
if len(scopeRateLimits) > 0 {
item.ScopeRateLimits = scopeRateLimits
}
if isOverloaded && acc.OverloadUntil != nil {
item.OverloadUntil = acc.OverloadUntil
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())

View File

@@ -50,24 +50,22 @@ type UserConcurrencyInfo struct {
// PlatformAvailability aggregates account availability by platform.
type PlatformAvailability struct {
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"`
}
// GroupAvailability aggregates account availability by group.
type GroupAvailability struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"`
}
// AccountAvailability represents current availability for a single account.
@@ -85,11 +83,10 @@ type AccountAvailability struct {
IsOverloaded bool `json:"is_overloaded"`
HasError bool `json:"has_error"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
OverloadUntil *time.Time `json:"overload_until"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
OverloadUntil *time.Time `json:"overload_until"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
}

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
@@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
switch reqType {
case opsRetryTypeMessages:
parsed, parseErr := ParseGatewayRequest(body)
parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
if parseErr != nil {
return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
}
@@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
if s.gatewayService == nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
}
parsedReq, parseErr := ParseGatewayRequest(body)
parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
if parseErr != nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
}

View File

@@ -62,6 +62,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
s.tokenCacheInvalidator = invalidator
}
// ErrorPolicyResult 表示错误策略检查的结果
type ErrorPolicyResult int
const (
ErrorPolicyNone ErrorPolicyResult = iota // 未命中任何策略,继续默认逻辑
ErrorPolicySkipped // 自定义错误码开启但未命中,跳过处理
ErrorPolicyMatched // 自定义错误码命中,应停止调度
ErrorPolicyTempUnscheduled // 临时不可调度规则命中
)
// CheckErrorPolicy 检查自定义错误码和临时不可调度规则。
// 自定义错误码开启时覆盖后续所有逻辑(包括临时不可调度)。
func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult {
if account.IsCustomErrorCodesEnabled() {
if account.ShouldHandleErrorCode(statusCode) {
return ErrorPolicyMatched
}
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return ErrorPolicySkipped
}
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
return ErrorPolicyTempUnscheduled
}
return ErrorPolicyNone
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {

View File

@@ -0,0 +1,318 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ============ shuffleWithinSortGroups 测试 ============
func TestShuffleWithinSortGroups_Empty(t *testing.T) {
shuffleWithinSortGroups(nil)
shuffleWithinSortGroups([]accountWithLoad{})
}
func TestShuffleWithinSortGroups_SingleElement(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
shuffleWithinSortGroups(accounts)
require.Equal(t, int64(1), accounts[0].account.ID)
}
func TestShuffleWithinSortGroups_DifferentGroups_OrderPreserved(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 3, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
// 每个元素都属于不同组Priority 或 LoadRate 或 LastUsedAt 不同),顺序不变
for i := 0; i < 20; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
require.Equal(t, int64(1), cpy[0].account.ID)
require.Equal(t, int64(2), cpy[1].account.ID)
require.Equal(t, int64(3), cpy[2].account.ID)
}
}
func TestShuffleWithinSortGroups_SameGroup_Shuffled(t *testing.T) {
now := time.Now()
// 同一秒的时间戳视为同一组
sameSecond := time.Unix(now.Unix(), 0)
sameSecond2 := time.Unix(now.Unix(), 500_000_000) // 同一秒但不同纳秒
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &sameSecond2}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
// 多次执行,验证所有 ID 都出现在第一个位置(说明确实被打乱了)
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
seen[cpy[0].account.ID] = true
// 无论怎么打乱,所有 ID 都应在候选中
ids := map[int64]bool{}
for _, a := range cpy {
ids[a.account.ID] = true
}
require.True(t, ids[1] && ids[2] && ids[3])
}
// 至少 2 个不同的 ID 出现在首位(随机性验证)
require.GreaterOrEqual(t, len(seen), 2, "shuffle should produce different orderings")
}
func TestShuffleWithinSortGroups_NilLastUsedAt_SameGroup(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
seen[cpy[0].account.ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "nil LastUsedAt accounts should be shuffled")
}
func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
sameAsNow := time.Unix(now.Unix(), 0)
// 组1: Priority=1, LoadRate=10, LastUsedAt=earlier (ID 1) — 单元素组
// 组2: Priority=1, LoadRate=20, LastUsedAt=now (ID 2, 3) — 双元素组
// 组3: Priority=2, LoadRate=10, LastUsedAt=earlier (ID 4) — 单元素组
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameAsNow}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 4, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
for i := 0; i < 20; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
// 组间顺序不变
require.Equal(t, int64(1), cpy[0].account.ID, "group 1 position fixed")
require.Equal(t, int64(4), cpy[3].account.ID, "group 3 position fixed")
// 组2 内部可以打乱,但仍在位置 1 和 2
mid := map[int64]bool{cpy[1].account.ID: true, cpy[2].account.ID: true}
require.True(t, mid[2] && mid[3], "group 2 elements should stay in positions 1-2")
}
}
// ============ shuffleWithinPriorityAndLastUsed 测试 ============
func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) {
shuffleWithinPriorityAndLastUsed(nil)
shuffleWithinPriorityAndLastUsed([]*Account{})
}
func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) {
accounts := []*Account{{ID: 1, Priority: 1}}
shuffleWithinPriorityAndLastUsed(accounts)
require.Equal(t, int64(1), accounts[0].ID)
}
func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: nil},
{ID: 3, Priority: 1, LastUsedAt: nil},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
seen[cpy[0].ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled")
}
func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 2, LastUsedAt: nil},
{ID: 3, Priority: 3, LastUsedAt: nil},
}
for i := 0; i < 20; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
require.Equal(t, int64(1), cpy[0].ID)
require.Equal(t, int64(2), cpy[1].ID)
require.Equal(t, int64(3), cpy[2].ID)
}
}
func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: &earlier},
{ID: 3, Priority: 1, LastUsedAt: &now},
}
for i := 0; i < 20; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
require.Equal(t, int64(1), cpy[0].ID)
require.Equal(t, int64(2), cpy[1].ID)
require.Equal(t, int64(3), cpy[2].ID)
}
}
// ============ sameLastUsedAt 测试 ============
func TestSameLastUsedAt(t *testing.T) {
now := time.Now()
sameSecond := time.Unix(now.Unix(), 0)
sameSecondDiffNano := time.Unix(now.Unix(), 999_999_999)
differentSecond := now.Add(1 * time.Second)
t.Run("both nil", func(t *testing.T) {
require.True(t, sameLastUsedAt(nil, nil))
})
t.Run("one nil one not", func(t *testing.T) {
require.False(t, sameLastUsedAt(nil, &now))
require.False(t, sameLastUsedAt(&now, nil))
})
t.Run("same second different nanoseconds", func(t *testing.T) {
require.True(t, sameLastUsedAt(&sameSecond, &sameSecondDiffNano))
})
t.Run("different seconds", func(t *testing.T) {
require.False(t, sameLastUsedAt(&now, &differentSecond))
})
t.Run("exact same time", func(t *testing.T) {
require.True(t, sameLastUsedAt(&now, &now))
})
}
// ============ sameAccountWithLoadGroup 测试 ============
func TestSameAccountWithLoadGroup(t *testing.T) {
now := time.Now()
sameSecond := time.Unix(now.Unix(), 0)
t.Run("same group", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.True(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different priority", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 2, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different load rate", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different last used at", func(t *testing.T) {
later := now.Add(1 * time.Second)
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &later}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("both nil LastUsedAt", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
require.True(t, sameAccountWithLoadGroup(a, b))
})
}
// ============ sameAccountGroup 测试 ============
func TestSameAccountGroup(t *testing.T) {
now := time.Now()
t.Run("same group", func(t *testing.T) {
a := &Account{Priority: 1, LastUsedAt: nil}
b := &Account{Priority: 1, LastUsedAt: nil}
require.True(t, sameAccountGroup(a, b))
})
t.Run("different priority", func(t *testing.T) {
a := &Account{Priority: 1, LastUsedAt: nil}
b := &Account{Priority: 2, LastUsedAt: nil}
require.False(t, sameAccountGroup(a, b))
})
t.Run("different LastUsedAt", func(t *testing.T) {
later := now.Add(1 * time.Second)
a := &Account{Priority: 1, LastUsedAt: &now}
b := &Account{Priority: 1, LastUsedAt: &later}
require.False(t, sameAccountGroup(a, b))
})
}
// ============ sortAccountsByPriorityAndLastUsed 集成随机化测试 ============
func TestSortAccountsByPriorityAndLastUsed_WithShuffle(t *testing.T) {
t.Run("same priority and nil LastUsedAt are shuffled", func(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: nil},
{ID: 3, Priority: 1, LastUsedAt: nil},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
sortAccountsByPriorityAndLastUsed(cpy, false)
seen[cpy[0].ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "identical sort keys should produce different orderings after shuffle")
})
t.Run("different priorities still sorted correctly", func(t *testing.T) {
now := time.Now()
accounts := []*Account{
{ID: 3, Priority: 3, LastUsedAt: &now},
{ID: 1, Priority: 1, LastUsedAt: &now},
{ID: 2, Priority: 2, LastUsedAt: &now},
}
sortAccountsByPriorityAndLastUsed(accounts, false)
require.Equal(t, int64(1), accounts[0].ID)
require.Equal(t, int64(2), accounts[1].ID)
require.Equal(t, int64(3), accounts[2].ID)
})
}

View File

@@ -275,4 +275,5 @@ var ProviderSet = wire.NewSet(
NewUsageCache,
NewTotpService,
NewErrorPassthroughService,
NewDigestSessionStore,
)

View File

@@ -376,7 +376,6 @@ export interface PlatformAvailability {
total_accounts: number
available_count: number
rate_limit_count: number
scope_rate_limit_count?: Record<string, number>
error_count: number
}
@@ -387,7 +386,6 @@ export interface GroupAvailability {
total_accounts: number
available_count: number
rate_limit_count: number
scope_rate_limit_count?: Record<string, number>
error_count: number
}
@@ -402,7 +400,6 @@ export interface AccountAvailability {
is_rate_limited: boolean
rate_limit_reset_at?: string
rate_limit_remaining_sec?: number
scope_rate_limits?: Record<string, number>
is_overloaded: boolean
overload_until?: string
overload_remaining_sec?: number

View File

@@ -76,26 +76,6 @@
</div>
</div>
<!-- Scope Rate Limit Indicators (Antigravity) -->
<template v-if="activeScopeRateLimits.length > 0">
<div v-for="item in activeScopeRateLimits" :key="item.scope" class="group relative">
<span
class="inline-flex items-center gap-1 rounded bg-orange-100 px-1.5 py-0.5 text-xs font-medium text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"
>
<Icon name="exclamationTriangle" size="xs" :stroke-width="2" />
{{ formatScopeName(item.scope) }}
</span>
<!-- Tooltip -->
<div
class="pointer-events-none absolute bottom-full left-1/2 z-50 mb-2 -translate-x-1/2 whitespace-nowrap rounded bg-gray-900 px-2 py-1 text-xs text-white opacity-0 transition-opacity group-hover:opacity-100 dark:bg-gray-700"
>
{{ t('admin.accounts.status.scopeRateLimitedUntil', { scope: formatScopeName(item.scope), time: formatTime(item.reset_at) }) }}
<div
class="absolute left-1/2 top-full -translate-x-1/2 border-4 border-transparent border-t-gray-900 dark:border-t-gray-700" ></div>
</div>
</div>
</template>
<!-- Model Rate Limit Indicators (Antigravity OAuth Smart Retry) -->
<template v-if="activeModelRateLimits.length > 0">
<div v-for="item in activeModelRateLimits" :key="item.model" class="group relative">
@@ -160,15 +140,6 @@ const isRateLimited = computed(() => {
return new Date(props.account.rate_limit_reset_at) > new Date()
})
// Computed: active scope rate limits (Antigravity)
const activeScopeRateLimits = computed(() => {
const scopeLimits = props.account.scope_rate_limits
if (!scopeLimits) return []
const now = new Date()
return Object.entries(scopeLimits)
.filter(([, info]) => new Date(info.reset_at) > now)
.map(([scope, info]) => ({ scope, reset_at: info.reset_at }))
})
// Computed: active model rate limits (Antigravity OAuth Smart Retry)
const activeModelRateLimits = computed(() => {

View File

@@ -1038,10 +1038,7 @@
</div>
<!-- Custom Error Codes Section -->
<div
v-if="form.platform !== 'gemini'"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
>
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.customErrorCodes') }}</label>

View File

@@ -1360,7 +1360,6 @@ export default {
overloaded: 'Overloaded',
tempUnschedulable: 'Temp Unschedulable',
rateLimitedUntil: 'Rate limited until {time}',
scopeRateLimitedUntil: '{scope} rate limited until {time}',
modelRateLimitedUntil: '{model} rate limited until {time}',
overloadedUntil: 'Overloaded until {time}',
viewTempUnschedDetails: 'View temp unschedulable details'
@@ -3063,7 +3062,6 @@ export default {
empty: 'No data',
queued: 'Queue {count}',
rateLimited: 'Rate-limited {count}',
scopeRateLimitedTooltip: '{scope} rate-limited ({count} accounts)',
errorAccounts: 'Errors {count}',
loadFailed: 'Failed to load concurrency data'
},

View File

@@ -1496,7 +1496,6 @@ export default {
overloaded: '过载中',
tempUnschedulable: '临时不可调度',
rateLimitedUntil: '限流中,重置时间:{time}',
scopeRateLimitedUntil: '{scope} 限流中,重置时间:{time}',
modelRateLimitedUntil: '{model} 限流至 {time}',
overloadedUntil: '负载过重,重置时间:{time}',
viewTempUnschedDetails: '查看临时不可调度详情'
@@ -3236,7 +3235,6 @@ export default {
empty: '暂无数据',
queued: '队列 {count}',
rateLimited: '限流 {count}',
scopeRateLimitedTooltip: '{scope} 限流中 ({count} 个账号)',
errorAccounts: '异常 {count}',
loadFailed: '加载并发数据失败'
},

View File

@@ -594,9 +594,6 @@ export interface Account {
temp_unschedulable_until: string | null
temp_unschedulable_reason: string | null
// Antigravity scope 级限流状态
scope_rate_limits?: Record<string, { reset_at: string; remaining_sec: number }>
// Session window fields (5-hour window)
session_window_start: string | null
session_window_end: string | null

View File

@@ -56,7 +56,6 @@ interface SummaryRow {
total_accounts: number
available_accounts: number
rate_limited_accounts: number
scope_rate_limit_count?: Record<string, number>
error_accounts: number
// 并发统计
total_concurrency: number
@@ -122,7 +121,7 @@ const platformRows = computed((): SummaryRow[] => {
total_accounts: totalAccounts,
available_accounts: availableAccounts,
rate_limited_accounts: safeNumber(avail.rate_limit_count),
scope_rate_limit_count: avail.scope_rate_limit_count,
error_accounts: safeNumber(avail.error_count),
total_concurrency: totalConcurrency,
used_concurrency: usedConcurrency,
@@ -162,7 +161,7 @@ const groupRows = computed((): SummaryRow[] => {
total_accounts: totalAccounts,
available_accounts: availableAccounts,
rate_limited_accounts: safeNumber(avail.rate_limit_count),
scope_rate_limit_count: avail.scope_rate_limit_count,
error_accounts: safeNumber(avail.error_count),
total_concurrency: totalConcurrency,
used_concurrency: usedConcurrency,
@@ -329,14 +328,6 @@ function formatDuration(seconds: number): string {
return `${hours}h`
}
function formatScopeName(scope: string): string {
const names: Record<string, string> = {
claude: 'Claude',
gemini_text: 'Gemini',
gemini_image: 'Image'
}
return names[scope] || scope
}
watch(
() => realtimeEnabled.value,
@@ -505,18 +496,6 @@ watch(
{{ t('admin.ops.concurrency.rateLimited', { count: row.rate_limited_accounts }) }}
</span>
<!-- Scope 限流 ( Antigravity) -->
<template v-if="row.scope_rate_limit_count && Object.keys(row.scope_rate_limit_count).length > 0">
<span
v-for="(count, scope) in row.scope_rate_limit_count"
:key="scope"
class="rounded-full bg-orange-100 px-1.5 py-0.5 font-semibold text-orange-700 dark:bg-orange-900/30 dark:text-orange-400"
:title="t('admin.ops.concurrency.scopeRateLimitedTooltip', { scope, count })"
>
{{ formatScopeName(scope as string) }} {{ count }}
</span>
</template>
<!-- 异常账号 -->
<span
v-if="row.error_accounts > 0"