Compare commits

..

20 Commits

Author SHA1 Message Date
shaw
e4d74ae11d feat(ui): 用户列表页显示当前并发数
优化 /admin/users 页面的并发数列,显示「当前/最大」格式,
参考 AccountCapacityCell 的设计风格。

- 后端 UserHandler 注入 ConcurrencyService,批量查询用户当前并发数
- 新增 UserConcurrencyCell 组件,支持颜色状态(空闲灰/使用中黄/满载红)
- 前端 AdminUser 类型添加 current_concurrency 字段
2026-02-08 16:44:51 +08:00
shaw
8a0a8558cf feat(ui): OpenAI OAuth 账号支持批量 RT 输入创建
新增通过手动输入 Refresh Token 创建 OpenAI OAuth 账号功能,
参考 Anthropic sessionKey 批量创建方式:

- useOpenAIOAuth 添加 validateRefreshToken 方法
- accounts.ts 添加 refreshOpenAIToken API
- AuthInputMethod 类型新增 refresh_token 选项
- 支持多行输入 RT(每行一个)批量创建账号
- 账号名称自动累加后缀 #1, #2 等
- UI 显示 RT 数量徽章和批量创建提示
- 添加中英文 i18n 翻译
2026-02-08 16:10:15 +08:00
Wesley Liddick
2185a3b674 Merge pull request #517 from touwaeriol/fix/upstream-baseurl
refactor(upstream): replace upstream account type with apikey + auto-append base_url
2026-02-08 14:03:12 +08:00
Wesley Liddick
9e3c306a5b Merge pull request #513 from touwaeriol/pr/antigravity-full-v2
feat(antigravity): comprehensive enhancements — rate limiting, scheduling & smart retry
2026-02-08 14:01:17 +08:00
shaw
b1c30df8e3 fix(ui): unify admin table toolbar layout with search and buttons in single row
Standardize filter bar layout across admin pages to place search/filters
on left and action buttons on right within the same row, improving
visual consistency and space utilization.
2026-02-08 14:00:02 +08:00
erio
69816f8691 fix: remove unused upstreamHopByHopHeaders variable to pass golangci-lint 2026-02-08 13:30:39 +08:00
shaw
b4ec65785d fix: apikey类型账号test去掉oauth-2025-04-20 2026-02-08 13:26:28 +08:00
erio
3c93644146 chore: bump version to 0.1.74.7 2026-02-08 13:14:58 +08:00
erio
fb58560d15 refactor(upstream): replace upstream account type with apikey, auto-append /antigravity
Upstream accounts now use the standard APIKey type instead of a dedicated
upstream type. GetBaseURL() and new GetGeminiBaseURL() automatically append
/antigravity for Antigravity platform APIKey accounts, eliminating the need
for separate upstream forwarding methods.

- Remove ForwardUpstream, ForwardUpstreamGemini, testUpstreamConnection
- Remove upstream branch guards in Forward/ForwardGemini/TestConnection
- Add migration 052 to convert existing upstream accounts to apikey
- Update frontend CreateAccountModal to create apikey type
- Add unit tests for GetBaseURL and GetGeminiBaseURL
2026-02-08 13:06:25 +08:00
erio
6ab77f5eb5 fix(upstream): passthrough response body directly instead of parsing SSE
ForwardUpstream/ForwardUpstreamGemini should pipe the upstream response
directly to the client (headers + body), not parse it as SSE stream.
2026-02-08 08:49:43 +08:00
erio
4f57d7f761 fix: add nil guard for gin.Context in header passthrough to satisfy staticcheck SA5011 2026-02-08 08:36:35 +08:00
erio
1563bd3dda feat(upstream): passthrough all client headers instead of manual header setting
Replace manual header setting (Content-Type, anthropic-version, anthropic-beta)
with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini.
Only authentication headers (Authorization, x-api-key) are overridden with
upstream account credentials. Hop-by-hop headers are excluded.

Add unit tests covering header passthrough, auth override, and hop-by-hop filtering.
2026-02-08 08:33:09 +08:00
erio
df3346387f fix(frontend): upstream account edit fields and mixed_scheduling on create
- EditAccountModal: add Base URL / API Key fields for upstream type
- EditAccountModal: initialize editBaseUrl from credentials on upstream account open
- EditAccountModal: save upstream credentials (base_url, api_key) on submit
- CreateAccountModal: pass mixed_scheduling extra when creating upstream account
2026-02-08 02:08:51 +08:00
erio
77b66653ed fix(gateway): restore upstream account forwarding with dedicated methods
v0.1.74 merged upstream accounts into the OAuth path, causing requests
to hit the wrong protocol and endpoint. Add three upstream-specific
methods (testUpstreamConnection, ForwardUpstream, ForwardUpstreamGemini)
that use base_url + apiKey auth and passthrough the original body, while
reusing the existing response handling and error/retry logic.
2026-02-08 01:21:02 +08:00
erio
3077fd279d feat: smart retry max 1 attempt + clear sticky session on failure
- Change antigravitySmartRetryMaxAttempts from 3 to 1 to prevent
  repeated rate limiting and long waits
- Clear sticky session binding (DeleteSessionAccountID) after smart
  retry exhaustion, so subsequent requests don't hit the same
  rate-limited account
- Add flow diagrams to Forward/ForwardGemini doc comments
- Add comprehensive unit tests covering:
  - Sticky session cleared on retry failure (429, 503, network error)
  - Sticky session NOT cleared on retry success
  - Sticky session NOT cleared for non-sticky requests (empty hash)
  - Sticky session NOT cleared on long delay path (handled by handler)
  - Nil cache safety (no panic)
  - MaxAttempts constant verification
  - End-to-end retryLoop → switchError propagation with session clear
2026-02-07 19:30:58 +08:00
erio
e3748da860 fix(lint): handle errcheck for strings.Builder.WriteString 2026-02-07 18:18:15 +08:00
erio
36e6fb5fc8 ci: trigger CI for new PR 2026-02-07 18:13:37 +08:00
erio
86b503f87f refactor: remove Anthropic digest chain from Messages handler
The digest chain fallback is only needed for Gemini endpoints, not
for the Anthropic Messages API path. Remove the handler integration
while keeping the reusable service/repository layer for future use.
2026-02-07 18:01:04 +08:00
erio
50a783ff01 feat: add Anthropic sticky session digest chain matching via Trie
The previous fallback (step 3) in GenerateSessionHash hashed system +
all messages together, producing a different hash each round as the
conversation grew ([a] -> [a,b] -> [a,b,c]). This made fallback sticky
sessions ineffective for multi-turn conversations.

Implement per-message Trie digest chain matching (reusing Gemini's Trie
infrastructure) so that the previous round's chain is always a prefix
of the current round's chain, enabling reliable session affinity.
2026-02-07 18:00:56 +08:00
erio
e1a68497d6 refactor: simplify sticky session rate limit handling — switch immediately on any rate limit
Remove threshold-based waiting in both sticky session and antigravity
pre-check paths. When a model is rate-limited, immediately clear the
sticky session and switch accounts instead of waiting for short durations.
2026-02-07 17:06:49 +08:00
37 changed files with 1982 additions and 266 deletions

View File

@@ -1 +1 @@
0.1.70
0.1.74.7

View File

@@ -102,7 +102,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
@@ -126,13 +128,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)

View File

@@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc)
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc)

View File

@@ -11,15 +11,23 @@ import (
"github.com/gin-gonic/gin"
)
// UserWithConcurrency wraps AdminUser with current concurrency info
type UserWithConcurrency struct {
dto.AdminUser
CurrentConcurrency int `json:"current_concurrency"`
}
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
adminService service.AdminService
concurrencyService *service.ConcurrencyService
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService) *UserHandler {
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
return &UserHandler{
adminService: adminService,
adminService: adminService,
concurrencyService: concurrencyService,
}
}
@@ -87,10 +95,30 @@ func (h *UserHandler) List(c *gin.Context) {
return
}
out := make([]dto.AdminUser, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
// Batch get current concurrency (nil map if unavailable)
var loadInfo map[int64]*service.UserLoadInfo
if len(users) > 0 && h.concurrencyService != nil {
usersConcurrency := make([]service.UserWithConcurrency, len(users))
for i := range users {
usersConcurrency[i] = service.UserWithConcurrency{
ID: users[i].ID,
MaxConcurrency: users[i].Concurrency,
}
}
loadInfo, _ = h.concurrencyService.GetUsersLoadBatch(c.Request.Context(), usersConcurrency)
}
// Build response with concurrency info
out := make([]UserWithConcurrency, len(users))
for i := range users {
out[i] = UserWithConcurrency{
AdminUser: *dto.UserFromServiceAdmin(&users[i]),
}
if info := loadInfo[users[i].ID]; info != nil {
out[i].CurrentConcurrency = info.CurrentConcurrency
}
}
response.Paginated(c, out, total, page, pageSize)
}

View File

@@ -482,7 +482,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)

View File

@@ -410,7 +410,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
} else {
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)

View File

@@ -238,3 +238,41 @@ func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, pre
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}
// ============ Anthropic 会话 Fallback 方法 (复用 Trie 实现) ============
// FindAnthropicSession 查找 Anthropic 会话(复用 Gemini Trie Lua 脚本)
func (c *gatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" {
return "", 0, false
}
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
if err != nil || result == nil {
return "", 0, false
}
value, ok := result.(string)
if !ok || value == "" {
return "", 0, false
}
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
return uuid, accountID, ok
}
// SaveAnthropicSession 保存 Anthropic 会话(复用 Gemini Trie Lua 脚本)
func (c *gatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" {
return nil
}
trieKey := service.BuildAnthropicTrieKey(groupID, prefixHash)
value := service.FormatGeminiSessionValue(uuid, accountID)
ttlSeconds := int(service.AnthropicSessionTTL().Seconds())
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}

View File

@@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string {
if baseURL == "" {
return "https://api.anthropic.com"
}
if a.Platform == PlatformAntigravity {
return strings.TrimRight(baseURL, "/") + "/antigravity"
}
return baseURL
}
// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。
// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。
func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string {
baseURL := strings.TrimSpace(a.GetCredential("base_url"))
if baseURL == "" {
return defaultBaseURL
}
if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey {
return strings.TrimRight(baseURL, "/") + "/antigravity"
}
return baseURL
}

View File

@@ -0,0 +1,160 @@
//go:build unit
package service
import (
"testing"
)
func TestGetBaseURL(t *testing.T) {
tests := []struct {
name string
account Account
expected string
}{
{
name: "non-apikey type returns empty",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAnthropic,
},
expected: "",
},
{
name: "apikey without base_url returns default anthropic",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAnthropic,
Credentials: map[string]any{},
},
expected: "https://api.anthropic.com",
},
{
name: "apikey with custom base_url",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAnthropic,
Credentials: map[string]any{"base_url": "https://custom.example.com"},
},
expected: "https://custom.example.com",
},
{
name: "antigravity apikey auto-appends /antigravity",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity apikey trims trailing slash before appending",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity non-apikey returns empty",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetBaseURL()
if result != tt.expected {
t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGetGeminiBaseURL(t *testing.T) {
const defaultGeminiURL = "https://generativelanguage.googleapis.com"
tests := []struct {
name string
account Account
expected string
}{
{
name: "apikey without base_url returns default",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{},
},
expected: defaultGeminiURL,
},
{
name: "apikey with custom base_url",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"},
},
expected: "https://custom-gemini.example.com",
},
{
name: "antigravity apikey auto-appends /antigravity",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity apikey trims trailing slash",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity oauth does NOT append /antigravity",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com",
},
{
name: "oauth without base_url returns default",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{},
},
expected: defaultGeminiURL,
},
{
name: "nil credentials returns default",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
expected: defaultGeminiURL,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetGeminiBaseURL(defaultGeminiURL)
if result != tt.expected {
t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected)
}
})
}
}

View File

@@ -245,7 +245,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Set common headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
// Apply Claude Code client headers
for key, value := range claude.DefaultHeaders {
@@ -254,8 +253,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Set authentication header
if useBearer {
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
req.Header.Set("Authorization", "Bearer "+authToken)
} else {
req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader)
req.Header.Set("x-api-key", authToken)
}

View File

@@ -0,0 +1,89 @@
package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
// Anthropic 会话 Fallback 相关常量
const (
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL5 分钟)
anthropicSessionTTLSeconds = 300
// anthropicTrieKeyPrefix Anthropic Trie 会话 key 前缀
anthropicTrieKeyPrefix = "anthropic:trie:"
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)
// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL
func AnthropicSessionTTL() time.Duration {
return anthropicSessionTTLSeconds * time.Second
}
// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链
// 格式: s:<hash>-u:<hash>-a:<hash>-u:<hash>-...
// s = system, u = user, a = assistant
func BuildAnthropicDigestChain(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var parts []string
// 1. system prompt
if parsed.System != nil {
systemData, _ := json.Marshal(parsed.System)
if len(systemData) > 0 && string(systemData) != "null" {
parts = append(parts, "s:"+shortHash(systemData))
}
}
// 2. messages
for _, msg := range parsed.Messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
role, _ := msgMap["role"].(string)
prefix := rolePrefix(role)
content := msgMap["content"]
contentData, _ := json.Marshal(content)
parts = append(parts, prefix+":"+shortHash(contentData))
}
return strings.Join(parts, "-")
}
// rolePrefix 将 Anthropic 的 role 映射为单字符前缀
func rolePrefix(role string) string {
switch role {
case "assistant":
return "a"
default:
return "u"
}
}
// BuildAnthropicTrieKey 构建 Anthropic Trie Redis key
// 格式: anthropic:trie:{groupID}:{prefixHash}
func BuildAnthropicTrieKey(groupID int64, prefixHash string) string {
return anthropicTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
prefix := prefixHash
if len(prefixHash) >= 8 {
prefix = prefixHash[:8]
}
uuidPart := uuid
if len(uuid) >= 8 {
uuidPart = uuid[:8]
}
return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart
}

View File

@@ -0,0 +1,357 @@
package service
import (
"strings"
"testing"
)
func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) {
result := BuildAnthropicDigestChain(nil)
if result != "" {
t.Errorf("expected empty string for nil request, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{},
}
result := BuildAnthropicDigestChain(parsed)
if result != "" {
t.Errorf("expected empty string for empty messages, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "a:") {
t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) {
parsed := &ParsedRequest{
System: "You are a helpful assistant",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "u:") {
t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) {
parsed := &ParsedRequest{
System: []any{
map[string]any{"type": "text", "text": "You are a helpful assistant"},
},
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) {
// 核心测试:验证对话增长时链的前缀关系
// 上一轮的完整链一定是下一轮链的前缀
system := "You are a helpful assistant"
// 第 1 轮: system + user
round1 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(round1)
// 第 2 轮: system + user + assistant + user
round2 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
},
}
chain2 := BuildAnthropicDigestChain(round2)
// 第 3 轮: system + user + assistant + user + assistant + user
round3 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
map[string]any{"role": "assistant", "content": "I'm doing well"},
map[string]any{"role": "user", "content": "great"},
},
}
chain3 := BuildAnthropicDigestChain(round3)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
t.Logf("Chain3: %s", chain3)
// chain1 是 chain2 的前缀
if !strings.HasPrefix(chain2, chain1) {
t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2)
}
// chain2 是 chain3 的前缀
if !strings.HasPrefix(chain3, chain2) {
t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3)
}
// chain1 也是 chain3 的前缀(传递性)
if !strings.HasPrefix(chain3, chain1) {
t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3)
}
}
func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
System: "System A",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
parsed2 := &ParsedRequest{
System: "System B",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different system prompts should produce different chains")
}
// 但 user 部分的 hash 应该相同
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
if parts1[1] != parts2[1] {
t.Error("Same user message should produce same hash regardless of system")
}
}
func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "ORIGINAL reply"},
map[string]any{"role": "user", "content": "next"},
},
}
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "TAMPERED reply"},
map[string]any{"role": "user", "content": "next"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different content should produce different chains")
}
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
// 第一个 user message hash 应该相同
if parts1[0] != parts2[0] {
t.Error("First user message hash should be the same")
}
// assistant reply hash 应该不同
if parts1[1] == parts2[1] {
t.Error("Assistant reply hash should differ")
}
}
func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
System: "test system",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi"},
},
}
chain1 := BuildAnthropicDigestChain(parsed)
chain2 := BuildAnthropicDigestChain(parsed)
if chain1 != chain2 {
t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2)
}
}
func TestBuildAnthropicTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "anthropic:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "anthropic:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "anthropic:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildAnthropicTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildAnthropicTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
tests := []struct {
name string
prefixHash string
uuid string
want string
}{
{
name: "normal 16 char hash with uuid",
prefixHash: "abcdefgh12345678",
uuid: "550e8400-e29b-41d4-a716-446655440000",
want: "anthropic:digest:abcdefgh:550e8400",
},
{
name: "exactly 8 chars",
prefixHash: "12345678",
uuid: "abcdefgh",
want: "anthropic:digest:12345678:abcdefgh",
},
{
name: "short values",
prefixHash: "abc",
uuid: "xyz",
want: "anthropic:digest:abc:xyz",
},
{
name: "empty values",
prefixHash: "",
uuid: "",
want: "anthropic:digest::",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid)
if got != tt.want {
t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
}
})
}
// 验证不同 uuid 产生不同 sessionKey
t.Run("different uuid different key", func(t *testing.T) {
hash := "sameprefix123456"
result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a")
result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b")
if result1 == result2 {
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
}
})
}
func TestAnthropicSessionTTL(t *testing.T) {
ttl := AnthropicSessionTTL()
if ttl.Seconds() != 300 {
t.Errorf("expected 300 seconds, got: %v", ttl.Seconds())
}
}
func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) {
// 测试 content 为 content blocks 数组的情况
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "describe this image"},
map[string]any{"type": "image", "source": map[string]any{"type": "base64"}},
},
},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}

View File

@@ -35,7 +35,7 @@ const (
// - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号
antigravityRateLimitThreshold = 7 * time.Second
antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间
antigravitySmartRetryMaxAttempts = 3 // 智能重试最大次数
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
// Google RPC 状态和类型常量
@@ -247,6 +247,11 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
}
}
// 清除粘性会话绑定,避免下次请求仍命中限流账号
if s.cache != nil && p.sessionHash != "" {
_ = s.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash)
}
// 返回账号切换信号,让上层切换账号重试
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
@@ -264,27 +269,15 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
// 预检查:如果账号已限流,根据剩余时间决定等待或切换
// 预检查:如果账号已限流,直接返回切换信号
if p.requestedModel != "" {
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
if remaining < antigravityRateLimitThreshold {
// 限流剩余时间较短,等待后继续
log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
select {
case <-p.ctx.Done():
return nil, p.ctx.Err()
case <-time.After(remaining):
}
} else {
// 限流剩余时间较长,返回账号切换信号
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: p.requestedModel,
IsStickySession: p.isStickySession,
}
log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
return nil, &AntigravityAccountSwitchError{
OriginalAccountID: p.account.ID,
RateLimitedModel: p.requestedModel,
IsStickySession: p.isStickySession,
}
}
}
@@ -650,6 +643,7 @@ type TestConnectionResult struct {
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
// 获取 token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
@@ -964,8 +958,19 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
}
// Forward 转发 Claude 协议请求Claude → Gemini 转换)
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
@@ -1583,8 +1588,19 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
}
// ForwardGemini 转发 Gemini 协议请求
//
// 限流处理流程:
//
// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游
// ├─ 成功 → 正常返回
// └─ 429/503 → handleSmartRetry
// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号
// └─ retryDelay < 7s → 等待后重试 1 次
// ├─ 成功 → 正常返回
// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)

View File

@@ -803,7 +803,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope")
}
func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) {
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 1,
@@ -815,19 +815,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
// RFC3339 here is second-precision; keep it safely in the future.
"rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339),
},
},
},
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
@@ -841,12 +837,16 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
},
})
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, result)
require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession)
require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check")
}
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) {
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 2,

View File

@@ -13,6 +13,23 @@ import (
"github.com/stretchr/testify/require"
)
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
// 仅关注 DeleteSessionAccountID 的调用记录
type stubSmartRetryCache struct {
GatewayCache // 嵌入接口,未实现的方法 panic确保只调用预期方法
deleteCalls []deleteSessionCall
}
type deleteSessionCall struct {
groupID int64
sessionHash string
}
func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID int64, sessionHash string) error {
c.deleteCalls = append(c.deleteCalls, deleteSessionCall{groupID: groupID, sessionHash: sessionHash})
return nil
}
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
type mockSmartRetryUpstream struct {
responses []*http.Response
@@ -198,7 +215,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError
func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) {
// 智能重试后仍然返回 429需要提供 3 个响应,因为智能重试最多 3 次)
// 智能重试后仍然返回 429需要提供 1 个响应,因为智能重试最多 1 次)
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
@@ -213,19 +230,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
failResp2 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
failResp3 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp1, failResp2, failResp3},
errors: []error{nil, nil, nil},
responses: []*http.Response{failResp1},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
@@ -236,7 +243,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
Platform: PlatformAntigravity,
}
// 3s < 7s 阈值,应该触发智能重试(最多 3 次)
// 3s < 7s 阈值,应该触发智能重试(最多 1 次)
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
@@ -284,7 +291,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey)
require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)")
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
@@ -556,19 +563,15 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
require.True(t, switchErr.IsStickySession)
}
// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试
func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
// 第一次网络错误,第二次成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
// TestHandleSmartRetry_NetworkError_ExhaustsRetry 测试网络错误时maxAttempts=1直接耗尽重试并切换账号
func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) {
// 唯一一次重试遇到网络错误nil response
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil, successResp}, // 第一次返回 nil模拟网络错误
errors: []error{nil, nil}, // mock 不返回 error靠 nil response 触发
responses: []*http.Response{nil}, // 返回 nil模拟网络错误
errors: []error{nil}, // mock 不返回 error靠 nil response 触发
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 8,
Name: "acc-8",
@@ -600,6 +603,7 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
@@ -612,10 +616,15 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after network error recovery")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not return switchError on success")
require.Len(t, upstream.calls, 2, "should have made two retry calls")
require.Nil(t, result.resp, "should not return resp when switchError is set")
require.NotNil(t, result.switchError, "should return switchError after network error exhausted retry")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel)
require.Len(t, upstream.calls, 1, "should have made one retry call")
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流
@@ -674,3 +683,617 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// ---------------------------------------------------------------------------
// 以下测试覆盖本次改动:
// 1. antigravitySmartRetryMaxAttempts = 1仅重试 1 次)
// 2. 智能重试失败后清除粘性会话绑定DeleteSessionAccountID
// ---------------------------------------------------------------------------
// TestSmartRetryMaxAttempts_VerifyConstant 验证常量值为 1
func TestSmartRetryMaxAttempts_VerifyConstant(t *testing.T) {
require.Equal(t, 1, antigravitySmartRetryMaxAttempts,
"antigravitySmartRetryMaxAttempts should be 1 to prevent repeated rate limiting")
}
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession
// 核心场景:粘性会话 + 短延迟重试失败 → 必须清除粘性绑定
func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 10,
Name: "acc-10",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-abc",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
// 验证返回 switchError
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession, "switchError should carry IsStickySession=true")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
// 核心断言DeleteSessionAccountID 被调用,且参数正确
require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID exactly once")
require.Equal(t, int64(42), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-hash-abc", cache.deleteCalls[0].sessionHash)
// 验证仅重试 1 次
require.Len(t, upstream.calls, 1, "should make exactly 1 retry call (maxAttempts=1)")
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession
// 非粘性会话 + 短延迟重试失败 → 不应调用 DeleteSessionAccountIDsessionHash 为空)
func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 11,
Name: "acc-11",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: false,
groupID: 42,
sessionHash: "", // 非粘性会话sessionHash 为空
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.False(t, result.switchError.IsStickySession)
// 核心断言sessionHash 为空时不应调用 DeleteSessionAccountID
require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID when sessionHash is empty")
}
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic
// 边界cache 为 nil 时不应 panic
func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 12,
Name: "acc-12",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-nil-cache",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
// cache 为 nil不应 panic
svc := &AntigravityGatewayService{cache: nil}
require.NotPanics(t, func() {
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
})
}
// TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession
// 重试成功时不应清除粘性会话(只有失败才清除)
func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 13,
Name: "acc-13",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-success",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not return switchError on success")
// 核心断言:重试成功时不应清除粘性会话
require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID on successful retry")
}
// TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry
// 长延迟路径情况1在 handleSmartRetry 中不直接调用 DeleteSessionAccountID
// (清除由 handler 层的 shouldClearStickySession 在下次请求时处理)
func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 14,
Name: "acc-14",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 15s >= 7s 阈值 → 走长延迟路径
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-long-delay",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 长延迟路径不在 handleSmartRetry 中调用 DeleteSessionAccountID
// (由上游 handler 的 shouldClearStickySession 处理)
require.Len(t, cache.deleteCalls, 0,
"long delay path should NOT call DeleteSessionAccountID in handleSmartRetry (handled by handler layer)")
}
// TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession
// 网络错误耗尽重试 + 粘性会话 → 也应清除粘性绑定
func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t *testing.T) {
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil}, // 网络错误
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 15,
Name: "acc-15",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 99,
sessionHash: "sticky-net-error",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 核心断言:网络错误耗尽重试后也应清除粘性绑定
require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID after network error exhausts retry")
require.Equal(t, int64(99), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-net-error", cache.deleteCalls[0].sessionHash)
}
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 16,
Name: "acc-16",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 77,
sessionHash: "sticky-503-short",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 验证粘性绑定被清除
require.Len(t, cache.deleteCalls, 1)
require.Equal(t, int64(77), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-503-short", cache.deleteCalls[0].sessionHash)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro", repo.modelRateLimitCalls[0].modelKey)
}
// TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates
// 集成测试antigravityRetryLoop → handleSmartRetry → switchError 传播
// 验证 IsStickySession 正确传递到上层,且粘性绑定被清除
func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates(t *testing.T) {
// 初始 429 响应
initialRespBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
initialResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(initialRespBody)),
}
// 智能重试也返回 429
retryRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
retryResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(retryRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{initialResp, retryResp},
errors: []error{nil, nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 17,
Name: "acc-17",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
svc := &AntigravityGatewayService{cache: cache}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 55,
sessionHash: "sticky-loop-test",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.Nil(t, result, "should not return result when switchError")
require.NotNil(t, err, "should return error")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError")
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession, "IsStickySession must propagate through retryLoop")
// 验证粘性绑定被清除
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
}

View File

@@ -232,6 +232,14 @@ func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, gro
return nil
}
func (m *mockGatewayCacheForPlatform) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int

View File

@@ -313,6 +313,14 @@ type GatewayCache interface {
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
// FindAnthropicSession 查找 Anthropic 会话Trie 匹配)
// Find Anthropic session using Trie matching
FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveAnthropicSession 保存 Anthropic 会话
// Save Anthropic session binding
SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -323,21 +331,15 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
// 低于此阈值时保持粘性会话,等待短暂限流结束。
const stickySessionRateLimitThreshold = 10 * time.Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
// 或请求的模型处于限流状态时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// within temporary unschedulable period, or model rate limit remaining time
// exceeds stickySessionRateLimitThreshold.
// within temporary unschedulable period, or the requested model is rate-limited.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
@@ -349,8 +351,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
// 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true
}
return false
@@ -488,23 +490,25 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent)
}
// 3. Fallback: 使用 system 内容
// 3. 最后 fallback: 使用 system + 所有消息的完整摘要串
var combined strings.Builder
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
return s.hashContent(systemText)
_, _ = combined.WriteString(systemText)
}
}
// 4. 最后 fallback: 使用第一条消息
if len(parsed.Messages) > 0 {
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
for _, msg := range parsed.Messages {
if m, ok := msg.(map[string]any); ok {
msgText := s.extractTextFromContent(m["content"])
if msgText != "" {
return s.hashContent(msgText)
_, _ = combined.WriteString(msgText)
}
}
}
if combined.Len() > 0 {
return s.hashContent(combined.String())
}
return ""
}
@@ -547,6 +551,22 @@ func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, p
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
}
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func (s *GatewayService) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
}
return s.cache.FindAnthropicSession(ctx, groupID, prefixHash, digestChain)
}
// SaveAnthropicSession 保存 Anthropic 会话
func (s *GatewayService) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
return nil
}
return s.cache.SaveAnthropicSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
@@ -1110,7 +1130,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
@@ -1264,7 +1283,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
@@ -2169,9 +2187,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2272,9 +2287,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
@@ -2383,9 +2395,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2488,9 +2497,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}

View File

@@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
@@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
@@ -1026,10 +1020,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
@@ -1097,10 +1088,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
@@ -2420,10 +2408,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err

View File

@@ -281,6 +281,14 @@ func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, group
return nil
}
func (m *mockGatewayCacheForGemini) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()

View File

@@ -220,6 +220,14 @@ func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64,
return nil
}
func (c *stubGatewayCache) FindAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveAnthropicSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)

View File

@@ -23,8 +23,7 @@ import (
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
// - 模型限流超过阈值:清理
// - 模型限流未超过阈值:不清理
// - 模型限流(任意时长):清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
@@ -35,9 +34,9 @@ func TestShouldClearStickySession(t *testing.T) {
future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour)
// 短限流时间(低于阈值,不应清除粘性会话)
// 短限流时间(有限流即清除粘性会话)
shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339)
// 长限流时间(超过阈值,应清除粘性会话)
// 长限流时间(有限流即清除粘性会话)
longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339)
tests := []struct {
@@ -53,7 +52,7 @@ func TestShouldClearStickySession(t *testing.T) {
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false},
// 模型限流测试
// 模型限流测试:有限流即清除
{
name: "model rate limited short duration",
account: &Account{
@@ -68,7 +67,7 @@ func TestShouldClearStickySession(t *testing.T) {
},
},
requestedModel: "claude-sonnet-4",
want: false, // 低于阈值,不清除
want: true, // 有限流即清除
},
{
name: "model rate limited long duration",
@@ -84,7 +83,7 @@ func TestShouldClearStickySession(t *testing.T) {
},
},
requestedModel: "claude-sonnet-4",
want: true, // 超过阈值,清除
want: true, // 有限流即清除
},
{
name: "model rate limited different model",

View File

@@ -0,0 +1,11 @@
-- Migrate upstream accounts to apikey type
-- Background: upstream type is no longer needed. Antigravity platform APIKey accounts
-- with base_url pointing to an upstream sub2api instance can reuse the standard
-- APIKey forwarding path. GetBaseURL()/GetGeminiBaseURL() automatically appends
-- /antigravity for Antigravity platform APIKey accounts.
UPDATE accounts
SET type = 'apikey'
WHERE type = 'upstream'
AND platform = 'antigravity'
AND deleted_at IS NULL;

View File

@@ -398,6 +398,26 @@ export async function getAntigravityDefaultModelMapping(): Promise<Record<string
return data
}
/**
* Refresh OpenAI token using refresh token
* @param refreshToken - The refresh token
* @param proxyId - Optional proxy ID
* @returns Token information including access_token, email, etc.
*/
export async function refreshOpenAIToken(
refreshToken: string,
proxyId?: number | null
): Promise<Record<string, unknown>> {
const payload: { refresh_token: string; proxy_id?: number } = {
refresh_token: refreshToken
}
if (proxyId) {
payload.proxy_id = proxyId
}
const { data } = await apiClient.post<Record<string, unknown>>('/admin/openai/refresh-token', payload)
return data
}
export const accountsAPI = {
list,
getById,
@@ -418,6 +438,7 @@ export const accountsAPI = {
getAvailableModels,
generateAuthUrl,
exchangeCode,
refreshOpenAIToken,
batchCreate,
batchUpdateCredentials,
bulkUpdate,

View File

@@ -1650,10 +1650,12 @@
:show-proxy-warning="form.platform !== 'openai' && !!form.proxy_id"
:allow-multiple="form.platform === 'anthropic'"
:show-cookie-option="form.platform === 'anthropic'"
:show-refresh-token-option="form.platform === 'openai'"
:platform="form.platform"
:show-project-id="geminiOAuthType === 'code_assist'"
@generate-url="handleGenerateUrl"
@cookie-auth="handleCookieAuth"
@validate-refresh-token="handleOpenAIValidateRT"
/>
</div>
@@ -2010,6 +2012,7 @@ interface OAuthFlowExposed {
oauthState: string
projectId: string
sessionKey: string
refreshToken: string
inputMethod: AuthInputMethod
reset: () => void
}
@@ -2289,9 +2292,9 @@ watch(
watch(
[accountCategory, addMethod, antigravityAccountType],
([category, method, agType]) => {
// Antigravity upstream 类型
// Antigravity upstream 类型(实际创建为 apikey
if (form.platform === 'antigravity' && agType === 'upstream') {
form.type = 'upstream'
form.type = 'apikey'
return
}
if (category === 'oauth-based') {
@@ -2714,7 +2717,8 @@ const handleSubmit = async () => {
submitting.value = true
try {
await createAccountAndFinish(form.platform, 'upstream', credentials)
const extra = mixedScheduling.value ? { mixed_scheduling: true } : undefined
await createAccountAndFinish(form.platform, 'apikey', credentials, extra)
} catch (error: any) {
appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate'))
} finally {
@@ -2860,6 +2864,95 @@ const handleOpenAIExchange = async (authCode: string) => {
}
}
// OpenAI 手动 RT 批量验证和创建
const handleOpenAIValidateRT = async (refreshTokenInput: string) => {
if (!refreshTokenInput.trim()) return
// Parse multiple refresh tokens (one per line)
const refreshTokens = refreshTokenInput
.split('\n')
.map((rt) => rt.trim())
.filter((rt) => rt)
if (refreshTokens.length === 0) {
openaiOAuth.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken')
return
}
openaiOAuth.loading.value = true
openaiOAuth.error.value = ''
let successCount = 0
let failedCount = 0
const errors: string[] = []
try {
for (let i = 0; i < refreshTokens.length; i++) {
try {
const tokenInfo = await openaiOAuth.validateRefreshToken(
refreshTokens[i],
form.proxy_id
)
if (!tokenInfo) {
failedCount++
errors.push(`#${i + 1}: ${openaiOAuth.error.value || 'Validation failed'}`)
openaiOAuth.error.value = ''
continue
}
const credentials = openaiOAuth.buildCredentials(tokenInfo)
const extra = openaiOAuth.buildExtraInfo(tokenInfo)
// Generate account name with index for batch
const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name
await adminAPI.accounts.create({
name: accountName,
notes: form.notes,
platform: 'openai',
type: 'oauth',
credentials,
extra,
proxy_id: form.proxy_id,
concurrency: form.concurrency,
priority: form.priority,
rate_multiplier: form.rate_multiplier,
group_ids: form.group_ids,
expires_at: form.expires_at,
auto_pause_on_expired: autoPauseOnExpired.value
})
successCount++
} catch (error: any) {
failedCount++
const errMsg = error.response?.data?.detail || error.message || 'Unknown error'
errors.push(`#${i + 1}: ${errMsg}`)
}
}
// Show results
if (successCount > 0 && failedCount === 0) {
appStore.showSuccess(
refreshTokens.length > 1
? t('admin.accounts.oauth.batchSuccess', { count: successCount })
: t('admin.accounts.accountCreated')
)
emit('created')
handleClose()
} else if (successCount > 0 && failedCount > 0) {
appStore.showWarning(
t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount })
)
openaiOAuth.error.value = errors.join('\n')
emit('created')
} else {
openaiOAuth.error.value = errors.join('\n')
appStore.showError(t('admin.accounts.oauth.batchFailed'))
}
} finally {
openaiOAuth.loading.value = false
}
}
// Gemini OAuth 授权码兑换
const handleGeminiExchange = async (authCode: string) => {
if (!authCode.trim() || !geminiOAuth.sessionId.value) return

View File

@@ -364,6 +364,30 @@
</div>
</div>
<!-- Upstream fields (only for upstream type) -->
<div v-if="account.type === 'upstream'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.upstream.baseUrl') }}</label>
<input
v-model="editBaseUrl"
type="text"
class="input"
placeholder="https://s.konstants.xyz"
/>
<p class="input-hint">{{ t('admin.accounts.upstream.baseUrlHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.upstream.apiKey') }}</label>
<input
v-model="editApiKey"
type="password"
class="input font-mono"
placeholder="sk-..."
/>
<p class="input-hint">{{ t('admin.accounts.leaveEmptyToKeep') }}</p>
</div>
</div>
<!-- Antigravity model restriction (applies to all antigravity types) -->
<!-- Antigravity 只支持模型映射模式不支持白名单模式 -->
<div v-if="account.platform === 'antigravity'" class="border-t border-gray-200 pt-4 dark:border-dark-600">
@@ -1244,6 +1268,9 @@ watch(
} else {
selectedErrorCodes.value = []
}
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
const credentials = newAccount.credentials as Record<string, unknown>
editBaseUrl.value = (credentials.base_url as string) || ''
} else {
const platformDefaultUrl =
newAccount.platform === 'openai'
@@ -1584,6 +1611,22 @@ const handleSubmit = async () => {
return
}
updatePayload.credentials = newCredentials
} else if (props.account.type === 'upstream') {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
newCredentials.base_url = editBaseUrl.value.trim()
if (editApiKey.value.trim()) {
newCredentials.api_key = editApiKey.value.trim()
}
if (!applyTempUnschedConfig(newCredentials)) {
submitting.value = false
return
}
updatePayload.credentials = newCredentials
} else {
// For oauth/setup-token types, only update intercept_warmup_requests if changed

View File

@@ -10,11 +10,11 @@
<h4 class="mb-3 font-semibold text-blue-900 dark:text-blue-200">{{ oauthTitle }}</h4>
<!-- Auth Method Selection -->
<div v-if="showCookieOption" class="mb-4">
<div v-if="showMethodSelection" class="mb-4">
<label class="mb-2 block text-sm font-medium text-blue-800 dark:text-blue-300">
{{ methodLabel }}
</label>
<div class="flex gap-4">
<div class="flex flex-wrap gap-4">
<label class="flex cursor-pointer items-center gap-2">
<input
v-model="inputMethod"
@@ -26,7 +26,7 @@
t('admin.accounts.oauth.manualAuth')
}}</span>
</label>
<label class="flex cursor-pointer items-center gap-2">
<label v-if="showCookieOption" class="flex cursor-pointer items-center gap-2">
<input
v-model="inputMethod"
type="radio"
@@ -37,6 +37,101 @@
t('admin.accounts.oauth.cookieAutoAuth')
}}</span>
</label>
<label v-if="showRefreshTokenOption" class="flex cursor-pointer items-center gap-2">
<input
v-model="inputMethod"
type="radio"
value="refresh_token"
class="text-blue-600 focus:ring-blue-500"
/>
<span class="text-sm text-blue-900 dark:text-blue-200">{{
t('admin.accounts.oauth.openai.refreshTokenAuth')
}}</span>
</label>
</div>
</div>
<!-- Refresh Token Input (OpenAI only) -->
<div v-if="inputMethod === 'refresh_token'" class="space-y-4">
<div
class="rounded-lg border border-blue-300 bg-white/80 p-4 dark:border-blue-600 dark:bg-gray-800/80"
>
<p class="mb-3 text-sm text-blue-700 dark:text-blue-300">
{{ t('admin.accounts.oauth.openai.refreshTokenDesc') }}
</p>
<!-- Refresh Token Input -->
<div class="mb-4">
<label
class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-700 dark:text-gray-300"
>
<Icon name="key" size="sm" class="text-blue-500" />
Refresh Token
<span
v-if="parsedRefreshTokenCount > 1"
class="rounded-full bg-blue-500 px-2 py-0.5 text-xs text-white"
>
{{ t('admin.accounts.oauth.keysCount', { count: parsedRefreshTokenCount }) }}
</span>
</label>
<textarea
v-model="refreshTokenInput"
rows="3"
class="input w-full resize-y font-mono text-sm"
:placeholder="t('admin.accounts.oauth.openai.refreshTokenPlaceholder')"
></textarea>
<p
v-if="parsedRefreshTokenCount > 1"
class="mt-1 text-xs text-blue-600 dark:text-blue-400"
>
{{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedRefreshTokenCount }) }}
</p>
</div>
<!-- Error Message -->
<div
v-if="error"
class="mb-4 rounded-lg border border-red-200 bg-red-50 p-3 dark:border-red-700 dark:bg-red-900/30"
>
<p class="whitespace-pre-line text-sm text-red-600 dark:text-red-400">
{{ error }}
</p>
</div>
<!-- Validate Button -->
<button
type="button"
class="btn btn-primary w-full"
:disabled="loading || !refreshTokenInput.trim()"
@click="handleValidateRefreshToken"
>
<svg
v-if="loading"
class="-ml-1 mr-2 h-4 w-4 animate-spin"
fill="none"
viewBox="0 0 24 24"
>
<circle
class="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
stroke-width="4"
></circle>
<path
class="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
></path>
</svg>
<Icon v-else name="sparkles" size="sm" class="mr-2" />
{{
loading
? t('admin.accounts.oauth.openai.validating')
: t('admin.accounts.oauth.openai.validateAndCreate')
}}
</button>
</div>
</div>
@@ -173,7 +268,7 @@
</div>
<!-- Manual Authorization Flow -->
<div v-else class="space-y-4">
<div v-if="inputMethod === 'manual'" class="space-y-4">
<p class="mb-4 text-sm text-blue-800 dark:text-blue-300">
{{ oauthFollowSteps }}
</p>
@@ -428,6 +523,7 @@ interface Props {
allowMultiple?: boolean
methodLabel?: string
showCookieOption?: boolean // Whether to show cookie auto-auth option
showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only)
platform?: 'anthropic' | 'openai' | 'gemini' | 'antigravity' // Platform type for different UI/text
showProjectId?: boolean // New prop to control project ID visibility
}
@@ -442,6 +538,7 @@ const props = withDefaults(defineProps<Props>(), {
allowMultiple: false,
methodLabel: 'Authorization Method',
showCookieOption: true,
showRefreshTokenOption: false,
platform: 'anthropic',
showProjectId: true
})
@@ -450,6 +547,7 @@ const emit = defineEmits<{
'generate-url': []
'exchange-code': [code: string]
'cookie-auth': [sessionKey: string]
'validate-refresh-token': [refreshToken: string]
'update:inputMethod': [method: AuthInputMethod]
}>()
@@ -487,10 +585,14 @@ const oauthImportantNotice = computed(() => {
const inputMethod = ref<AuthInputMethod>(props.showCookieOption ? 'manual' : 'manual')
const authCodeInput = ref('')
const sessionKeyInput = ref('')
const refreshTokenInput = ref('')
const showHelpDialog = ref(false)
const oauthState = ref('')
const projectId = ref('')
// Computed: show method selection when either cookie or refresh token option is enabled
const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption)
// Clipboard
const { copied, copyToClipboard } = useClipboard()
@@ -502,6 +604,14 @@ const parsedKeyCount = computed(() => {
.filter((k) => k).length
})
// Computed: count of refresh tokens entered
const parsedRefreshTokenCount = computed(() => {
return refreshTokenInput.value
.split('\n')
.map((rt) => rt.trim())
.filter((rt) => rt).length
})
// Watchers
watch(inputMethod, (newVal) => {
emit('update:inputMethod', newVal)
@@ -563,18 +673,26 @@ const handleCookieAuth = () => {
}
}
const handleValidateRefreshToken = () => {
if (refreshTokenInput.value.trim()) {
emit('validate-refresh-token', refreshTokenInput.value.trim())
}
}
// Expose methods and state
defineExpose({
authCode: authCodeInput,
oauthState,
projectId,
sessionKey: sessionKeyInput,
refreshToken: refreshTokenInput,
inputMethod,
reset: () => {
authCodeInput.value = ''
oauthState.value = ''
projectId.value = ''
sessionKeyInput.value = ''
refreshTokenInput.value = ''
inputMethod.value = 'manual'
showHelpDialog.value = false
}

View File

@@ -0,0 +1,43 @@
<template>
<div class="flex items-center">
<span
:class="[
'inline-flex items-center gap-1 rounded-md px-2 py-0.5 text-xs font-medium',
statusClass
]"
>
<!-- Four-square grid icon -->
<svg class="h-3 w-3" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6A2.25 2.25 0 016 3.75h2.25A2.25 2.25 0 0110.5 6v2.25a2.25 2.25 0 01-2.25 2.25H6a2.25 2.25 0 01-2.25-2.25V6zM3.75 15.75A2.25 2.25 0 016 13.5h2.25a2.25 2.25 0 012.25 2.25V18a2.25 2.25 0 01-2.25 2.25H6A2.25 2.25 0 013.75 18v-2.25zM13.5 6a2.25 2.25 0 012.25-2.25H18A2.25 2.25 0 0120.25 6v2.25A2.25 2.25 0 0118 10.5h-2.25a2.25 2.25 0 01-2.25-2.25V6zM13.5 15.75a2.25 2.25 0 012.25-2.25H18a2.25 2.25 0 012.25 2.25V18A2.25 2.25 0 0118 20.25h-2.25A2.25 2.25 0 0113.5 18v-2.25z" />
</svg>
<span class="font-mono">{{ current }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="font-mono">{{ max }}</span>
</span>
</div>
</template>
<script setup lang="ts">
import { computed } from 'vue'
const props = defineProps<{
current: number
max: number
}>()
// Status color based on usage
const statusClass = computed(() => {
const { current, max } = props
// Full: red
if (current >= max && max > 0) {
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
}
// In use: yellow
if (current > 0) {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
}
// Idle: gray
return 'bg-gray-100 text-gray-600 dark:bg-gray-800 dark:text-gray-400'
})
</script>

View File

@@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin'
export type AddMethod = 'oauth' | 'setup-token'
export type AuthInputMethod = 'manual' | 'cookie'
export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token'
export interface OAuthState {
authUrl: string

View File

@@ -105,6 +105,32 @@ export function useOpenAIOAuth() {
}
}
// Validate refresh token and get full token info
const validateRefreshToken = async (
refreshToken: string,
proxyId?: number | null
): Promise<OpenAITokenInfo | null> => {
if (!refreshToken.trim()) {
error.value = 'Missing refresh token'
return null
}
loading.value = true
error.value = ''
try {
// Use dedicated refresh-token endpoint
const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(refreshToken.trim(), proxyId)
return tokenInfo as OpenAITokenInfo
} catch (err: any) {
error.value = err.response?.data?.detail || 'Failed to validate refresh token'
appStore.showError(error.value)
return null
} finally {
loading.value = false
}
}
// Build credentials for OpenAI OAuth account
const buildCredentials = (tokenInfo: OpenAITokenInfo): Record<string, unknown> => {
const creds: Record<string, unknown> = {
@@ -152,6 +178,7 @@ export function useOpenAIOAuth() {
resetState,
generateAuthUrl,
exchangeAuthCode,
validateRefreshToken,
buildCredentials,
buildExtraInfo
}

View File

@@ -1662,6 +1662,9 @@ export default {
cookieAuthFailed: 'Cookie authorization failed',
keyAuthFailed: 'Key {index}: {error}',
successCreated: 'Successfully created {count} account(s)',
batchSuccess: 'Successfully created {count} account(s)',
batchPartialSuccess: 'Partial success: {success} succeeded, {failed} failed',
batchFailed: 'Batch creation failed',
// OpenAI specific
openai: {
title: 'OpenAI Account Authorization',
@@ -1680,7 +1683,14 @@ export default {
authCodePlaceholder:
'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value',
authCodeHint:
'You can copy the entire URL or just the code parameter value, the system will auto-detect'
'You can copy the entire URL or just the code parameter value, the system will auto-detect',
// Refresh Token auth
refreshTokenAuth: 'Manual RT Input',
refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.',
refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line',
validating: 'Validating...',
validateAndCreate: 'Validate & Create Account',
pleaseEnterRefreshToken: 'Please enter Refresh Token'
},
// Gemini specific
gemini: {

View File

@@ -1804,6 +1804,9 @@ export default {
cookieAuthFailed: 'Cookie 授权失败',
keyAuthFailed: '密钥 {index}: {error}',
successCreated: '成功创建 {count} 个账号',
batchSuccess: '成功创建 {count} 个账号',
batchPartialSuccess: '部分成功:{success} 个成功,{failed} 个失败',
batchFailed: '批量创建失败',
// OpenAI specific
openai: {
title: 'OpenAI 账户授权',
@@ -1820,7 +1823,14 @@ export default {
authCode: '授权链接或 Code',
authCodePlaceholder:
'方式1复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2仅复制 code 参数的值',
authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别'
authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别',
// Refresh Token auth
refreshTokenAuth: '手动输入 RT',
refreshTokenDesc: '输入您已有的 OpenAI Refresh Token支持批量输入每行一个系统将自动验证并创建账号。',
refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个每行一个',
validating: '验证中...',
validateAndCreate: '验证并创建账号',
pleaseEnterRefreshToken: '请输入 Refresh Token'
},
// Gemini specific
gemini: {

View File

@@ -43,6 +43,8 @@ export interface AdminUser extends User {
notes: string
// 用户专属分组倍率配置 (group_id -> rate_multiplier)
group_rates?: Record<number, number>
// 当前并发数(仅管理员列表接口返回)
current_concurrency?: number
}
export interface LoginRequest {

View File

@@ -1,26 +1,10 @@
<template>
<AppLayout>
<TablePageLayout>
<template #actions>
<div class="flex justify-end gap-3">
<button
@click="loadAnnouncements"
:disabled="loading"
class="btn btn-secondary"
:title="t('common.refresh')"
>
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
<button @click="openCreateDialog" class="btn btn-primary">
<Icon name="plus" size="md" class="mr-1" />
{{ t('admin.announcements.createAnnouncement') }}
</button>
</div>
</template>
<template #filters>
<div class="flex flex-col gap-4 sm:flex-row sm:items-center sm:justify-between">
<div class="max-w-md flex-1">
<div class="flex flex-wrap items-center gap-3">
<!-- Left: Search + Filters -->
<div class="flex-1 sm:max-w-64">
<input
v-model="searchQuery"
type="text"
@@ -29,13 +13,27 @@
@input="handleSearch"
/>
</div>
<div class="flex gap-2">
<Select
v-model="filters.status"
:options="statusFilterOptions"
class="w-40"
@change="handleStatusChange"
/>
<Select
v-model="filters.status"
:options="statusFilterOptions"
class="w-40"
@change="handleStatusChange"
/>
<!-- Right: Action buttons -->
<div class="flex flex-1 flex-wrap items-center justify-end gap-2">
<button
@click="loadAnnouncements"
:disabled="loading"
class="btn btn-secondary"
:title="t('common.refresh')"
>
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
<button @click="openCreateDialog" class="btn btn-primary">
<Icon name="plus" size="md" class="mr-1" />
{{ t('admin.announcements.createAnnouncement') }}
</button>
</div>
</div>
</template>

View File

@@ -1,26 +1,10 @@
<template>
<AppLayout>
<TablePageLayout>
<template #actions>
<div class="flex justify-end gap-3">
<button
@click="loadCodes"
:disabled="loading"
class="btn btn-secondary"
:title="t('common.refresh')"
>
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
<button @click="showCreateDialog = true" class="btn btn-primary">
<Icon name="plus" size="md" class="mr-1" />
{{ t('admin.promo.createCode') }}
</button>
</div>
</template>
<template #filters>
<div class="flex flex-col gap-4 sm:flex-row sm:items-center sm:justify-between">
<div class="max-w-md flex-1">
<div class="flex flex-wrap items-center gap-3">
<!-- Left: Search + Filters -->
<div class="flex-1 sm:max-w-64">
<input
v-model="searchQuery"
type="text"
@@ -29,13 +13,27 @@
@input="handleSearch"
/>
</div>
<div class="flex gap-2">
<Select
v-model="filters.status"
:options="filterStatusOptions"
class="w-36"
@change="loadCodes"
/>
<Select
v-model="filters.status"
:options="filterStatusOptions"
class="w-36"
@change="loadCodes"
/>
<!-- Right: Action buttons -->
<div class="flex flex-1 flex-wrap items-center justify-end gap-2">
<button
@click="loadCodes"
:disabled="loading"
class="btn btn-secondary"
:title="t('common.refresh')"
>
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
<button @click="showCreateDialog = true" class="btn btn-primary">
<Icon name="plus" size="md" class="mr-1" />
{{ t('admin.promo.createCode') }}
</button>
</div>
</div>
</template>

View File

@@ -2,9 +2,42 @@
<AppLayout>
<TablePageLayout>
<template #filters>
<div class="space-y-3">
<!-- Row 1: Actions -->
<div class="flex flex-wrap items-center gap-3">
<div class="flex flex-wrap items-center gap-3">
<!-- Left: Search + Filters -->
<div class="relative w-full sm:w-64">
<Icon
name="search"
size="md"
class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-400 dark:text-gray-500"
/>
<input
v-model="searchQuery"
type="text"
:placeholder="t('admin.proxies.searchProxies')"
class="input pl-10"
@input="handleSearch"
/>
</div>
<div class="w-full sm:w-40">
<Select
v-model="filters.protocol"
:options="protocolOptions"
:placeholder="t('admin.proxies.allProtocols')"
@change="loadProxies"
/>
</div>
<div class="w-full sm:w-36">
<Select
v-model="filters.status"
:options="statusOptions"
:placeholder="t('admin.proxies.allStatus')"
@change="loadProxies"
/>
</div>
<!-- Right: All action buttons -->
<div class="flex flex-1 flex-wrap items-center justify-end gap-2">
<button
@click="loadProxies"
:disabled="loading"
@@ -42,41 +75,6 @@
{{ t('admin.proxies.createProxy') }}
</button>
</div>
<!-- Row 2: Search + Filters -->
<div class="flex flex-wrap items-center gap-3">
<div class="relative w-full sm:w-64">
<Icon
name="search"
size="md"
class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-400 dark:text-gray-500"
/>
<input
v-model="searchQuery"
type="text"
:placeholder="t('admin.proxies.searchProxies')"
class="input pl-10"
@input="handleSearch"
/>
</div>
<div class="w-full sm:w-40">
<Select
v-model="filters.protocol"
:options="protocolOptions"
:placeholder="t('admin.proxies.allProtocols')"
@change="loadProxies"
/>
</div>
<div class="w-full sm:w-36">
<Select
v-model="filters.status"
:options="statusOptions"
:placeholder="t('admin.proxies.allStatus')"
@change="loadProxies"
/>
</div>
</div>
</div>
</template>

View File

@@ -1,34 +1,18 @@
<template>
<AppLayout>
<TablePageLayout>
<template #actions>
<div class="flex justify-end gap-3">
<button
@click="loadCodes"
:disabled="loading"
class="btn btn-secondary"
:title="t('common.refresh')"
>
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
<button @click="showGenerateDialog = true" class="btn btn-primary">
{{ t('admin.redeem.generateCodes') }}
</button>
</div>
</template>
<template #filters>
<div class="flex flex-col gap-4 sm:flex-row sm:items-center sm:justify-between">
<div class="max-w-md flex-1">
<input
v-model="searchQuery"
type="text"
:placeholder="t('admin.redeem.searchCodes')"
class="input"
@input="handleSearch"
/>
<div class="flex flex-wrap items-center gap-3">
<!-- Left: Search + Filters -->
<div class="flex-1 sm:max-w-64">
<input
v-model="searchQuery"
type="text"
:placeholder="t('admin.redeem.searchCodes')"
class="input"
@input="handleSearch"
/>
</div>
<div class="flex gap-2">
<Select
v-model="filters.type"
:options="filterTypeOptions"
@@ -41,9 +25,23 @@
class="w-36"
@change="loadCodes"
/>
<button @click="handleExportCodes" class="btn btn-secondary">
{{ t('admin.redeem.exportCsv') }}
</button>
<!-- Right: Action buttons -->
<div class="flex flex-1 flex-wrap items-center justify-end gap-2">
<button
@click="loadCodes"
:disabled="loading"
class="btn btn-secondary"
:title="t('common.refresh')"
>
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
</button>
<button @click="handleExportCodes" class="btn btn-secondary">
{{ t('admin.redeem.exportCsv') }}
</button>
<button @click="showGenerateDialog = true" class="btn btn-primary">
{{ t('admin.redeem.generateCodes') }}
</button>
</div>
</div>
</template>

View File

@@ -3,9 +3,9 @@
<TablePageLayout>
<!-- Single Row: Search, Filters, and Actions -->
<template #filters>
<div class="flex w-full flex-col gap-3 md:flex-row md:flex-wrap-reverse md:items-center md:justify-between md:gap-4">
<div class="flex flex-wrap items-center gap-3">
<!-- Left: Search + Active Filters -->
<div class="flex min-w-[280px] flex-1 flex-wrap content-start items-center gap-3 md:order-1">
<div class="flex flex-1 flex-wrap items-center gap-3">
<!-- Search Box -->
<div class="relative w-full md:w-64">
<Icon
@@ -100,7 +100,7 @@
</div>
<!-- Right: Actions and Settings -->
<div class="flex w-full items-center justify-between gap-2 md:order-2 md:ml-auto md:max-w-full md:flex-wrap md:justify-end md:gap-3">
<div class="flex flex-wrap items-center justify-end gap-2">
<!-- Mobile: Secondary buttons (icon only) -->
<div class="flex items-center gap-2 md:contents">
<!-- Refresh Button -->
@@ -342,8 +342,11 @@
</div>
</template>
<template #cell-concurrency="{ value }">
<span class="text-sm text-gray-700 dark:text-gray-300">{{ value }}</span>
<template #cell-concurrency="{ row }">
<UserConcurrencyCell
:current="row.current_concurrency ?? 0"
:max="row.concurrency"
/>
</template>
<template #cell-status="{ value }">
@@ -535,6 +538,7 @@ import EmptyState from '@/components/common/EmptyState.vue'
import GroupBadge from '@/components/common/GroupBadge.vue'
import Select from '@/components/common/Select.vue'
import UserAttributesConfigModal from '@/components/user/UserAttributesConfigModal.vue'
import UserConcurrencyCell from '@/components/user/UserConcurrencyCell.vue'
import UserCreateModal from '@/components/admin/user/UserCreateModal.vue'
import UserEditModal from '@/components/admin/user/UserEditModal.vue'
import UserApiKeysModal from '@/components/admin/user/UserApiKeysModal.vue'