Compare commits

...

74 Commits

Author SHA1 Message Date
Wesley Liddick
723102766b Merge pull request #553 from Edric-Li/feat/antigravity-onboard-projectid
feat(antigravity): 添加 onboardUser 支持,修复 project_id 缺失问题
2026-02-11 13:52:44 +08:00
Edric Li
a4a46a8618 feat(antigravity): 添加 onboardUser 支持并修复 project_id 补齐逻辑
- 新增 OnboardUser API 客户端方法,支持账号 onboarding 获取 project_id
- loadProjectIDWithRetry 增加 onboard 回退:LoadCodeAssist 未返回 project_id 时自动触发 onboarding
- GetAccessToken 中 project_id 补齐改用轻量 FillProjectID 替代全量 RefreshAccountToken
- 补齐逻辑增加 5 分钟冷却机制,防止频繁重试
- OnboardUser 轮询等待改为 context 感知,支持提前取消
- 提取 mergeCredentials 辅助方法消除重复代码
- 新增 extractProjectIDFromOnboardResponse 和 resolveDefaultTierID 单元测试
2026-02-11 13:41:55 +08:00
Wesley Liddick
ae6fed15cc Merge pull request #548 from Edric-Li/main
feat: 错误处理增强、重试优化与性能改进
2026-02-10 22:46:58 +08:00
Edric Li
378e476e48 fix: 修复 CI 检查失败
- gofmt: 修复 error_passthrough_service.go 格式问题
- errcheck: 修复 error_passthrough_runtime_test.go 类型断言未检查
- staticcheck: if-else 改为 switch (gateway_service.go)
- test: 修复两个测试用例错误使用 MODEL_CAPACITY_EXHAUSTED 导致走错路径
2026-02-10 22:08:49 +08:00
Edric Li
2a1067c82b Merge remote-tracking branch 'upstream/main' 2026-02-10 21:52:33 +08:00
Edric Li
a54b81cf74 perf: 错误处理性能优化
- MatchRule 延迟/限制 body ToLower,先用 statusCode 短路,只在需要关键词匹配时转换且限制 8KB
- 预计算规则的小写关键词/平台和 error code set,消除运行时重复 ToLower 和线性扫描
- MODEL_CAPACITY_EXHAUSTED 全局去重,避免并发请求重复重试同一模型
- 503 重试 body 读取限制从 2MB 降至 8KB
- time.After 替换为 time.NewTimer,防止 context 取消时 timer 泄漏
2026-02-10 21:40:31 +08:00
Edric Li
2d4236f76e fix: 修复错误透传规则 skip_monitoring 未生效的问题
- ops_error_logger: status < 400 分支增加 OpsSkipPassthroughKey 检查
- ops_upstream_context: 新增 checkSkipMonitoringForUpstreamEvent,中间重试/故障转移事件也能触发跳过标记
- gateway_handler/openai_gateway_handler/gemini_v1beta_handler: handleFailoverExhausted 匹配规则后设置 OpsSkipPassthroughKey
- antigravity_gateway_service: writeMappedClaudeError 增加 applyErrorPassthroughRule 调用
2026-02-10 20:56:01 +08:00
Wesley Liddick
84ced1c497 Merge pull request #543 from slovx2/upstream_main
feat(antigravity): 转发与测试支持 daily/prod 单 URL 切换
2026-02-10 14:57:46 +08:00
song
b161312183 test(antigravity): 更新单URL策略下的重试断言 2026-02-10 14:36:09 +08:00
song
1f647b120a feat(antigravity): 转发与测试支持daily/prod单URL切换 2026-02-10 13:51:29 +08:00
Edric Li
7d0a30fa8f merge: sync upstream main (antigravity single-account 503 retry)
合并上游新增的 Antigravity 单账号 503 退避重试机制,
解决与本地 MODEL_CAPACITY_EXHAUSTED 逻辑的冲突,两者共存。
2026-02-10 12:00:21 +08:00
Edric Li
d95e04fd1f feat: 错误透传规则支持 skip_monitoring 跳过运维监控记录
在每条错误透传规则上新增 skip_monitoring 选项,开启后匹配该规则的错误
不会被记录到 ops_error_logs,减少监控噪音。默认关闭,不影响现有规则。
2026-02-10 11:42:39 +08:00
shaw
5dd83d3cf2 fix: 移除特定system以适配新版cc客户端缓存失效的bug 2026-02-10 10:28:34 +08:00
Wesley Liddick
14e1aac9b5 Merge pull request #533 from GuangYiDing/feat/antigravity-single-account-503-retry
feat: Antigravity 单账号分组 503 退避重试机制
2026-02-10 09:59:48 +08:00
Edric Li
6114f69cca feat: MODEL_CAPACITY_EXHAUSTED 使用固定1s间隔重试60次,不切换账号
MODEL_CAPACITY_EXHAUSTED (503) 表示模型容量不足,所有账号共享同一容量池,
切换账号无意义。改为固定1s间隔重试最多60次,重试耗尽后直接返回上游错误。

- 新增 antigravityModelCapacityRetryMaxAttempts=60 和 antigravityModelCapacityRetryWait=1s
- shouldTriggerAntigravitySmartRetry 新增 isModelCapacityExhausted 返回值
- handleSmartRetry 对 MODEL_CAPACITY_EXHAUSTED 使用独立重试策略
- handleModelRateLimit 对 MODEL_CAPACITY_EXHAUSTED 仅标记 Handled,不设限流
- 重试耗尽后不设置模型限流、不清除粘性会话、不切换账号
2026-02-10 02:03:06 +08:00
Edric Li
d6c2921f2b feat: same-account retry before failover for transient errors
For retryable transient errors (Google 400 "invalid project resource name"
and empty stream responses), retry on the same account up to 2 times
(with 500ms delay) before switching to another account.

- Add RetryableOnSameAccount field to UpstreamFailoverError
- Add same-account retry loop in both Gemini and Claude/OpenAI handler paths
- Move temp-unschedule from service layer to handler layer (only after
  all same-account retries exhausted)
- Reduce temp-unschedule cooldown from 30 minutes to 1 minute
2026-02-10 00:53:54 +08:00
Edric Li
61c73287dc feat: failover and temp-unschedule on empty stream response
- Empty stream responses now return UpstreamFailoverError instead of
  plain 502, triggering automatic account switching (up to 10 retries)
- Add tempUnscheduleEmptyResponse: accounts returning empty responses
  are temp-unscheduled for 30 minutes
- Apply to both Claude and Gemini non-streaming paths
- Align googleConfigErrorCooldown from 60m to 30m for consistency
2026-02-09 23:25:30 +08:00
Edric Li
89905ec43d feat: failover and temp-unschedule on Google "Invalid project resource name" 400
Google 后端间歇性返回 400 "Invalid project resource name" 错误,
此前该错误直接透传给客户端且不触发账号切换,导致请求失败。

- 在 Antigravity 和 Gemini 两个平台的所有转发路径中,
  精确匹配该错误消息后触发 failover 自动换号重试
- 命中后将账号临时封禁 1 小时,避免反复调度到同一故障账号
- 提取共享函数 isGoogleProjectConfigError / tempUnscheduleGoogleConfigError
  消除跨 Service 的代码重复
2026-02-09 22:48:32 +08:00
shaw
aa4b102108 fix: 移除Antigravity的apikey账户额外的表单 2026-02-09 22:15:14 +08:00
Rose Ding
e4bc35151f test: 添加单账号 503 退避重试机制的单元测试
覆盖 Service 层和 Handler 层的所有新增逻辑:
- isSingleAccountRetry context 标记检查
- handleSmartRetry 中 503 + SingleAccountRetry 分支
- handleSingleAccountRetryInPlace 原地重试逻辑
- antigravityRetryLoop 预检查跳过限流
- sleepAntigravitySingleAccountBackoff 固定延迟退避
- 端到端集成场景验证

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 22:06:06 +08:00
Wesley Liddick
56da498b7e Merge pull request #532 from touwaeriol/fix/clear-model-rate-limits
fix: support clearing model-level rate limits from action menu and temp-unsched reset
2026-02-09 20:52:44 +08:00
Wesley Liddick
1bba1a62b1 Merge pull request #531 from touwaeriol/fix/gemini-error-policy-before-retry
fix: Gemini error policy check should precede retry logic
2026-02-09 20:52:32 +08:00
erio
4a84ca9a02 fix: support clearing model-level rate limits from action menu and temp-unsched reset 2026-02-09 20:37:30 +08:00
erio
a70d37a676 fix: Gemini error policy check should precede retry logic 2026-02-09 19:55:17 +08:00
erio
6892e84ad2 fix: skip rate limiting when custom error codes don't match upstream status
Add ShouldHandleErrorCode guard at the entry of handleGeminiUpstreamError
and AntigravityGatewayService.handleUpstreamError so that accounts with
custom error codes (e.g. [599]) are not rate-limited when the upstream
returns a non-matching status (e.g. 429).
2026-02-09 19:55:05 +08:00
erio
73f455745c feat: ErrorPolicySkipped returns 500 instead of upstream status code
When custom error codes are enabled and the upstream error code is NOT
in the configured list, return HTTP 500 to the client instead of
transparently forwarding the original status code.

Also adds integration test TestCustomErrorCode599 verifying that 429,
500, 503, 401, 403 all return 500 without triggering SetRateLimited
or SetError.
2026-02-09 19:54:54 +08:00
Rose Ding
021abfca18 fix: 单账号分组首次 503 不设模型限流标记,避免后续请求雪崩
单账号 antigravity 分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时,
原逻辑会设置 ~29s 模型限流标记。由于只有一个账号无法切换,
后续所有新请求在预检查时命中限流 → 几毫秒内直接返回 503,
导致约 30 秒的雪崩窗口。

修复:在 Handler 入口处检查分组是否只有单个 antigravity 账号,
如果是则提前设置 SingleAccountRetry context 标记,让 Service 层
首次 503 就走原地重试逻辑(不设限流标记),避免污染后续请求。
2026-02-09 17:25:36 +08:00
Wesley Liddick
7d66f7ff0d Merge pull request #527 from touwaeriol/fix/group-badge-platform-color
fix: pass platform prop to GroupBadge in GroupSelector
2026-02-09 14:39:51 +08:00
erio
470b37be7e fix: pass platform prop to GroupBadge in GroupSelector
GroupBadge in GroupSelector was missing the platform prop, causing all
group badges in account edit/detail pages to use fallback colors instead
of platform-specific colors (e.g. Claude=orange, Gemini=blue).
2026-02-09 14:33:05 +08:00
Rose Ding
f6cfab9901 feat: 添加 Antigravity 单账号 503 退避重试机制
当分组内只有一个可用账号且上游返回 503 (MODEL_CAPACITY_EXHAUSTED) 时,
不再设置模型限流+切换账号(因为切换回来还是同一个账号),而是在 Service 层
原地等待+重试,避免双重等待问题。

主要变更:
- Handler 层:检测单账号 503 场景,清除排除列表并设置 SingleAccountRetry 标记
- Service 层:新增 handleSingleAccountRetryInPlace 原地重试逻辑
- Service 层:预检查跳过单账号模式下的限流检查
- 新增 ctxkey.SingleAccountRetry 上下文标记
2026-02-09 14:26:01 +08:00
shaw
51572b5da0 chore: update version 2026-02-09 12:00:03 +08:00
Wesley Liddick
91ca28b7e3 Merge pull request #525 from DaydreamCoding/feat/crs_sync_preview_with_select
feat(admin): 新增 CRS 同步预览和账号选择功能
2026-02-09 11:58:51 +08:00
QTom
04cedce9a1 test: 为 stubAccountRepo 添加 ListCRSAccountIDs 方法实现 2026-02-09 11:40:37 +08:00
QTom
5e0d789440 feat(admin): 新增 CRS 同步预览和账号选择功能
- 后端新增 PreviewFromCRS 接口,允许用户先预览 CRS 中的账号
- 后端支持在同步时选择特定账号,不选中的账号将被跳过
- 前端重构 SyncFromCrsModal 为三步向导:输入凭据 → 预览账号 → 执行同步
- 改进表单无障碍性:添加 for/id 关联和 required 属性
- 修复 Back 按钮返回时的状态清理
- 新增 buildSelectedSet 和 shouldCreateAccount 的单元测试
- 完整的向后兼容性:旧客户端不发送 selected_account_ids 时行为不变
2026-02-09 10:39:09 +08:00
Wesley Liddick
149e4267cd Merge pull request #523 from touwaeriol/feat/antigravity-improvements
feat: Antigravity improvements and scope-to-model rate limiting refactor
2026-02-09 09:38:55 +08:00
erio
9a479d1b55 fix: resolve CI failures from scope removal refactor
- Fix gofmt alignment in ops_realtime_models.go
- Remove SetAntigravityQuotaScopeLimit mock from api_contract_test.go
- Add UpdateSortOrders mock to mockGroupRepoForGateway
2026-02-09 08:27:14 +08:00
erio
fc095bf054 refactor: replace scope-level rate limiting with model-level rate limiting
Merge functional changes from develop branch:
- Remove AntigravityQuotaScope system (claude/gemini_text/gemini_image)
- Replace with per-model rate limiting using resolveAntigravityModelKey
- Remove model load statistics (IncrModelCallCount/GetModelLoadBatch)
- Simplify account selection to unified priority→load→LRU algorithm
- Remove SetAntigravityQuotaScopeLimit from AccountRepository
- Clean up scope-related UI indicators and API fields
2026-02-09 08:19:01 +08:00
erio
1af06aed96 feat: shuffle accounts within same sort group to prevent thundering herd
Add post-sort shuffle for accounts with identical (priority, loadRate,
lastUsedAt) to break deterministic ordering when concurrent requests
read the same scheduler snapshot. Applies to both Antigravity and
OpenAI scheduling paths, plus the sortAccountsByPriorityAndLastUsed
helper.

Keeps upstream CallCount/ModelLoadInfo scheduling intact; shuffle is
additive and only randomises within equivalent-rank groups.
2026-02-09 07:33:17 +08:00
erio
9236936a55 feat: route AccountTypeUpstream to ForwardUpstream in Forward() entry
Without this routing guard, ForwardUpstream is never called because
Forward() always proceeds with the standard OAuth/cookie flow.
2026-02-09 07:27:10 +08:00
erio
125152460f fix: use upstream retryDelay for rate limit duration instead of fixed default
- In handleSmartRetry, use the actual upstream retryDelay to set model
  rate limit duration instead of always using the 30s default
- Return info.RetryDelay from shouldTriggerAntigravitySmartRetry when
  shouldRateLimitModel=true, so callers know the actual delay
- Extract getDefaultRateLimitDuration() and resolveResetTime() helpers
  to reduce duplication in handleUpstreamError 429 handling
- Improve debug logging with upstream_retry_delay and response body
2026-02-09 07:11:29 +08:00
erio
6d90fb0bc3 feat: detect client disconnect during streaming and continue draining upstream for billing 2026-02-09 07:06:26 +08:00
erio
b889d5017b refactor: replace Trie-based digest session store with flat cache 2026-02-09 07:02:12 +08:00
erio
72b08f9cc5 fix: ensure sticky session failover triggers cache billing exemption 2026-02-09 06:57:07 +08:00
erio
681950dadd feat: add linear delay between Antigravity account failover switches 2026-02-09 06:56:29 +08:00
erio
a67d9337b8 feat: integrate CheckErrorPolicy into Gemini error handling paths 2026-02-09 06:55:45 +08:00
erio
2f1182e8a9 feat: unified error policy for Antigravity + enable custom error codes for Gemini accounts 2026-02-09 06:54:42 +08:00
erio
cbb4d854ab fix: check type assertion in test to satisfy errcheck linter 2026-02-09 06:47:50 +08:00
erio
35598d5648 fix: parse Gemini native request format in ParseGatewayRequest for correct session hash generation
ParseGatewayRequest only parsed Anthropic format (system/messages),
ignoring Gemini native format (systemInstruction/contents). This caused
GenerateSessionHash to produce identical hashes for all Gemini sessions.

Add protocol parameter to ParseGatewayRequest to branch between
Anthropic and Gemini parsing. Update GenerateSessionHash message
traversal to extract text from both formats.
2026-02-09 06:47:22 +08:00
erio
5c76b9e45a fix: prevent sessionHash collision for different users with same messages
Mix SessionContext (ClientIP, UserAgent, APIKeyID) into
GenerateSessionHash 3rd-level fallback to differentiate requests
from different users sending identical content.

Also switch hashContent from SHA256-truncated to XXHash64 for
better performance, and optimize Trie Lua script to match from
longest prefix first.
2026-02-09 06:46:32 +08:00
erio
0b8fea4cb4 fix: clean thoughtSignature for all clients, not just CLI
Previously, thoughtSignature cleanup only applied to Gemini CLI
requests (detected via x-gemini-api-privileged-user-id header or
tmp dir pattern). This caused 400 errors for non-CLI clients when
session cache expired and they sent stale signatures.

Remove the isGeminiCLIRequest guard so all clients benefit from
proactive thoughtSignature cleanup on session binding miss.
2026-02-09 06:45:01 +08:00
Wesley Liddick
5fa93ebdc7 Merge pull request #519 from bayma888/feature/group-sort-order
feat(admin): 新增-分组管理自由拖拽排序功能
2026-02-08 18:00:22 +08:00
bayma888
8aa0aed566 docs: add development guide for team reference
记录项目环境配置、CI 流程、常见坑点和解决方案。
2026-02-08 17:54:03 +08:00
bayma888
2eb32a0ed7 chore: update pnpm-lock.yaml for vue-draggable-plus
CI 的 pnpm install --frozen-lockfile 需要 lock 文件同步更新
2026-02-08 17:10:25 +08:00
bayma888
bac9e2bfd5 feat(admin): add drag-and-drop group sort order
- Add `sort_order` field to groups table with migration
- Add `PUT /api/v1/admin/groups/sort-order` API for batch update
- Implement drag-and-drop UI using vue-draggable-plus
- All queries now order groups by sort_order
- Add i18n support (en/zh) for sort-related UI text
- Update test stubs to satisfy new interface methods
2026-02-08 16:53:45 +08:00
shaw
e4d74ae11d feat(ui): 用户列表页显示当前并发数
优化 /admin/users 页面的并发数列,显示「当前/最大」格式,
参考 AccountCapacityCell 的设计风格。

- 后端 UserHandler 注入 ConcurrencyService,批量查询用户当前并发数
- 新增 UserConcurrencyCell 组件,支持颜色状态(空闲灰/使用中黄/满载红)
- 前端 AdminUser 类型添加 current_concurrency 字段
2026-02-08 16:44:51 +08:00
shaw
8a0a8558cf feat(ui): OpenAI OAuth 账号支持批量 RT 输入创建
新增通过手动输入 Refresh Token 创建 OpenAI OAuth 账号功能,
参考 Anthropic sessionKey 批量创建方式:

- useOpenAIOAuth 添加 validateRefreshToken 方法
- accounts.ts 添加 refreshOpenAIToken API
- AuthInputMethod 类型新增 refresh_token 选项
- 支持多行输入 RT(每行一个)批量创建账号
- 账号名称自动累加后缀 #1, #2 等
- UI 显示 RT 数量徽章和批量创建提示
- 添加中英文 i18n 翻译
2026-02-08 16:10:15 +08:00
Wesley Liddick
2185a3b674 Merge pull request #517 from touwaeriol/fix/upstream-baseurl
refactor(upstream): replace upstream account type with apikey + auto-append base_url
2026-02-08 14:03:12 +08:00
Wesley Liddick
9e3c306a5b Merge pull request #513 from touwaeriol/pr/antigravity-full-v2
feat(antigravity): comprehensive enhancements — rate limiting, scheduling & smart retry
2026-02-08 14:01:17 +08:00
shaw
b1c30df8e3 fix(ui): unify admin table toolbar layout with search and buttons in single row
Standardize filter bar layout across admin pages to place search/filters
on left and action buttons on right within the same row, improving
visual consistency and space utilization.
2026-02-08 14:00:02 +08:00
erio
69816f8691 fix: remove unused upstreamHopByHopHeaders variable to pass golangci-lint 2026-02-08 13:30:39 +08:00
shaw
b4ec65785d fix: apikey类型账号test去掉oauth-2025-04-20 2026-02-08 13:26:28 +08:00
erio
3c93644146 chore: bump version to 0.1.74.7 2026-02-08 13:14:58 +08:00
erio
fb58560d15 refactor(upstream): replace upstream account type with apikey, auto-append /antigravity
Upstream accounts now use the standard APIKey type instead of a dedicated
upstream type. GetBaseURL() and new GetGeminiBaseURL() automatically append
/antigravity for Antigravity platform APIKey accounts, eliminating the need
for separate upstream forwarding methods.

- Remove ForwardUpstream, ForwardUpstreamGemini, testUpstreamConnection
- Remove upstream branch guards in Forward/ForwardGemini/TestConnection
- Add migration 052 to convert existing upstream accounts to apikey
- Update frontend CreateAccountModal to create apikey type
- Add unit tests for GetBaseURL and GetGeminiBaseURL
2026-02-08 13:06:25 +08:00
erio
6ab77f5eb5 fix(upstream): passthrough response body directly instead of parsing SSE
ForwardUpstream/ForwardUpstreamGemini should pipe the upstream response
directly to the client (headers + body), not parse it as SSE stream.
2026-02-08 08:49:43 +08:00
erio
4f57d7f761 fix: add nil guard for gin.Context in header passthrough to satisfy staticcheck SA5011 2026-02-08 08:36:35 +08:00
erio
1563bd3dda feat(upstream): passthrough all client headers instead of manual header setting
Replace manual header setting (Content-Type, anthropic-version, anthropic-beta)
with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini.
Only authentication headers (Authorization, x-api-key) are overridden with
upstream account credentials. Hop-by-hop headers are excluded.

Add unit tests covering header passthrough, auth override, and hop-by-hop filtering.
2026-02-08 08:33:09 +08:00
erio
df3346387f fix(frontend): upstream account edit fields and mixed_scheduling on create
- EditAccountModal: add Base URL / API Key fields for upstream type
- EditAccountModal: initialize editBaseUrl from credentials on upstream account open
- EditAccountModal: save upstream credentials (base_url, api_key) on submit
- CreateAccountModal: pass mixed_scheduling extra when creating upstream account
2026-02-08 02:08:51 +08:00
erio
77b66653ed fix(gateway): restore upstream account forwarding with dedicated methods
v0.1.74 merged upstream accounts into the OAuth path, causing requests
to hit the wrong protocol and endpoint. Add three upstream-specific
methods (testUpstreamConnection, ForwardUpstream, ForwardUpstreamGemini)
that use base_url + apiKey auth and passthrough the original body, while
reusing the existing response handling and error/retry logic.
2026-02-08 01:21:02 +08:00
erio
3077fd279d feat: smart retry max 1 attempt + clear sticky session on failure
- Change antigravitySmartRetryMaxAttempts from 3 to 1 to prevent
  repeated rate limiting and long waits
- Clear sticky session binding (DeleteSessionAccountID) after smart
  retry exhaustion, so subsequent requests don't hit the same
  rate-limited account
- Add flow diagrams to Forward/ForwardGemini doc comments
- Add comprehensive unit tests covering:
  - Sticky session cleared on retry failure (429, 503, network error)
  - Sticky session NOT cleared on retry success
  - Sticky session NOT cleared for non-sticky requests (empty hash)
  - Sticky session NOT cleared on long delay path (handled by handler)
  - Nil cache safety (no panic)
  - MaxAttempts constant verification
  - End-to-end retryLoop → switchError propagation with session clear
2026-02-07 19:30:58 +08:00
erio
e3748da860 fix(lint): handle errcheck for strings.Builder.WriteString 2026-02-07 18:18:15 +08:00
erio
36e6fb5fc8 ci: trigger CI for new PR 2026-02-07 18:13:37 +08:00
erio
86b503f87f refactor: remove Anthropic digest chain from Messages handler
The digest chain fallback is only needed for Gemini endpoints, not
for the Anthropic Messages API path. Remove the handler integration
while keeping the reusable service/repository layer for future use.
2026-02-07 18:01:04 +08:00
erio
50a783ff01 feat: add Anthropic sticky session digest chain matching via Trie
The previous fallback (step 3) in GenerateSessionHash hashed system +
all messages together, producing a different hash each round as the
conversation grew ([a] -> [a,b] -> [a,b,c]). This made fallback sticky
sessions ineffective for multi-turn conversations.

Implement per-message Trie digest chain matching (reusing Gemini's Trie
infrastructure) so that the previous round's chain is always a prefix
of the current round's chain, enabling reliable session affinity.
2026-02-07 18:00:56 +08:00
erio
e1a68497d6 refactor: simplify sticky session rate limit handling — switch immediately on any rate limit
Remove threshold-based waiting in both sticky session and antigravity
pre-check paths. When a model is rate-limited, immediately clear the
sticky session and switch accounts instead of waiting for short durations.
2026-02-07 17:06:49 +08:00
132 changed files with 10986 additions and 2434 deletions

323
DEV_GUIDE.md Normal file
View File

@@ -0,0 +1,323 @@
# sub2api 项目开发指南
> 本文档记录项目环境配置、常见坑点和注意事项,供 Claude Code 和团队成员参考。
## 一、项目基本信息
| 项目 | 说明 |
|------|------|
| **上游仓库** | Wei-Shaw/sub2api |
| **Fork 仓库** | bayma888/sub2api-bmai |
| **技术栈** | Go 后端 (Ent ORM + Gin) + Vue3 前端 (pnpm) |
| **数据库** | PostgreSQL 16 + Redis |
| **包管理** | 后端: go modules, 前端: **pnpm**(不是 npm |
## 二、本地环境配置
### PostgreSQL 16 (Windows 服务)
| 配置项 | 值 |
|--------|-----|
| 端口 | 5432 |
| psql 路径 | `C:\Program Files\PostgreSQL\16\bin\psql.exe` |
| pg_hba.conf | `C:\Program Files\PostgreSQL\16\data\pg_hba.conf` |
| 数据库凭据 | user=`sub2api`, password=`sub2api`, dbname=`sub2api` |
| 超级用户 | user=`postgres`, password=`postgres` |
### Redis
| 配置项 | 值 |
|--------|-----|
| 端口 | 6379 |
| 密码 | 无 |
### 开发工具
```bash
# golangci-lint v2.7
go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.7
# pnpm (前端包管理)
npm install -g pnpm
```
## 三、CI/CD 流水线
### GitHub Actions Workflows
| Workflow | 触发条件 | 检查内容 |
|----------|----------|----------|
| **backend-ci.yml** | push, pull_request | 单元测试 + 集成测试 + golangci-lint v2.7 |
| **security-scan.yml** | push, pull_request, 每周一 | govulncheck + gosec + pnpm audit |
| **release.yml** | tag `v*` | 构建发布PR 不触发) |
### CI 要求
- Go 版本必须是 **1.25.7**
- 前端使用 `pnpm install --frozen-lockfile`,必须提交 `pnpm-lock.yaml`
### 本地测试命令
```bash
# 后端单元测试
cd backend && go test -tags=unit ./...
# 后端集成测试
cd backend && go test -tags=integration ./...
# 代码质量检查
cd backend && golangci-lint run ./...
# 前端依赖安装(必须用 pnpm
cd frontend && pnpm install
```
## 四、常见坑点 & 解决方案
### 坑 1pnpm-lock.yaml 必须同步提交
**问题**`package.json` 新增依赖后CI 的 `pnpm install --frozen-lockfile` 失败。
**原因**:上游 CI 使用 pnpmlock 文件不同步会报错。
**解决**
```bash
cd frontend
pnpm install # 更新 pnpm-lock.yaml
git add pnpm-lock.yaml
git commit -m "chore: update pnpm-lock.yaml"
```
---
### 坑 2npm 和 pnpm 的 node_modules 冲突
**问题**:之前用 npm 装过 `node_modules`pnpm install 报 `EPERM` 错误。
**解决**
```bash
cd frontend
rm -rf node_modules # 或 PowerShell: Remove-Item -Recurse -Force node_modules
pnpm install
```
---
### 坑 3PowerShell 中 bcrypt hash 的 `$` 被转义
**问题**bcrypt hash 格式如 `$2a$10$xxx...`PowerShell 把 `$2a` 当变量解析,导致数据丢失。
**解决**:将 SQL 写入文件,用 `psql -f` 执行:
```bash
# 错误示范PowerShell 会吃掉 $
psql -c "INSERT INTO users ... VALUES ('$2a$10$...')"
# 正确做法
echo "INSERT INTO users ... VALUES ('\$2a\$10\$...')" > temp.sql
psql -U sub2api -h 127.0.0.1 -d sub2api -f temp.sql
```
---
### 坑 4psql 不支持中文路径
**问题**`psql -f "D:\中文路径\file.sql"` 报错找不到文件。
**解决**:复制到纯英文路径再执行:
```bash
cp "D:\中文路径\file.sql" "C:\temp.sql"
psql -f "C:\temp.sql"
```
---
### 坑 5PostgreSQL 密码重置流程
**场景**:忘记 PostgreSQL 密码。
**步骤**
1. 修改 `C:\Program Files\PostgreSQL\16\data\pg_hba.conf`
```
# 将 scram-sha-256 改为 trust
host all all 127.0.0.1/32 trust
```
2. 重启 PostgreSQL 服务
```powershell
Restart-Service postgresql-x64-16
```
3. 无密码登录并重置
```bash
psql -U postgres -h 127.0.0.1
ALTER USER sub2api WITH PASSWORD 'sub2api';
ALTER USER postgres WITH PASSWORD 'postgres';
```
4. 改回 `scram-sha-256` 并重启
---
### 坑 6Go interface 新增方法后 test stub 必须补全
**问题**:给 interface 新增方法后,编译报错 `does not implement interface (missing method XXX)`。
**原因**:所有测试文件中实现该 interface 的 stub/mock 都必须补上新方法。
**解决**
```bash
# 搜索所有实现该 interface 的 struct
cd backend
grep -r "type.*Stub.*struct" internal/
grep -r "type.*Mock.*struct" internal/
# 逐一补全新方法
```
---
### 坑 7Windows 上 psql 连 localhost 的 IPv6 问题
**问题**psql 连 `localhost` 先尝试 IPv6 (::1),可能报错后再回退 IPv4。
**建议**:直接用 `127.0.0.1` 代替 `localhost`。
---
### 坑 8Windows 没有 make 命令
**问题**CI 里用 `make test-unit`,本地 Windows 没有 make。
**解决**:直接用 Makefile 里的原始命令:
```bash
# 代替 make test-unit
go test -tags=unit ./...
# 代替 make test-integration
go test -tags=integration ./...
```
---
### 坑 9Ent Schema 修改后必须重新生成
**问题**:修改 `ent/schema/*.go` 后,代码不生效。
**解决**
```bash
cd backend
go generate ./ent # 重新生成 ent 代码
git add ent/ # 生成的文件也要提交
```
---
### 坑 10PR 提交前检查清单
提交 PR 前务必本地验证:
- [ ] `go test -tags=unit ./...` 通过
- [ ] `go test -tags=integration ./...` 通过
- [ ] `golangci-lint run ./...` 无新增问题
- [ ] `pnpm-lock.yaml` 已同步(如果改了 package.json
- [ ] 所有 test stub 补全新接口方法(如果改了 interface
- [ ] Ent 生成的代码已提交(如果改了 schema
## 五、常用命令速查
### 数据库操作
```bash
# 连接数据库
psql -U sub2api -h 127.0.0.1 -d sub2api
# 查看所有用户
psql -U postgres -h 127.0.0.1 -c "\du"
# 查看所有数据库
psql -U postgres -h 127.0.0.1 -c "\l"
# 执行 SQL 文件
psql -U sub2api -h 127.0.0.1 -d sub2api -f migration.sql
```
### Git 操作
```bash
# 同步上游
git fetch upstream
git checkout main
git merge upstream/main
git push origin main
# 创建功能分支
git checkout -b feature/xxx
# Rebase 到最新 main
git fetch upstream
git rebase upstream/main
```
### 前端操作
```bash
# 安装依赖(必须用 pnpm
cd frontend
pnpm install
# 开发服务器
pnpm dev
# 构建
pnpm build
```
### 后端操作
```bash
# 运行服务器
cd backend
go run ./cmd/server/
# 生成 Ent 代码
go generate ./ent
# 运行测试
go test -tags=unit ./...
go test -tags=integration ./...
# Lint 检查
golangci-lint run ./...
```
## 六、项目结构速览
```
sub2api-bmai/
├── backend/
│ ├── cmd/server/ # 主程序入口
│ ├── ent/ # Ent ORM 生成代码
│ │ └── schema/ # 数据库 Schema 定义
│ ├── internal/
│ │ ├── handler/ # HTTP 处理器
│ │ ├── service/ # 业务逻辑
│ │ ├── repository/ # 数据访问层
│ │ └── server/ # 服务器配置
│ ├── migrations/ # 数据库迁移脚本
│ └── config.yaml # 配置文件
├── frontend/
│ ├── src/
│ │ ├── api/ # API 调用
│ │ ├── components/ # Vue 组件
│ │ ├── views/ # 页面视图
│ │ ├── types/ # TypeScript 类型
│ │ └── i18n/ # 国际化
│ ├── package.json # 依赖配置
│ └── pnpm-lock.yaml # pnpm 锁文件(必须提交)
└── .claude/
└── CLAUDE.md # 本文档
```
## 七、参考资源
- [上游仓库](https://github.com/Wei-Shaw/sub2api)
- [Ent 文档](https://entgo.io/docs/getting-started)
- [Vue3 文档](https://vuejs.org/)
- [pnpm 文档](https://pnpm.io/)

View File

@@ -1 +1 @@
0.1.70
0.1.76

View File

@@ -102,7 +102,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
@@ -126,13 +128,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
@@ -154,7 +154,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)

View File

@@ -44,6 +44,8 @@ type ErrorPassthroughRule struct {
PassthroughBody bool `json:"passthrough_body,omitempty"`
// CustomMessage holds the value of the "custom_message" field.
CustomMessage *string `json:"custom_message,omitempty"`
// SkipMonitoring holds the value of the "skip_monitoring" field.
SkipMonitoring bool `json:"skip_monitoring,omitempty"`
// Description holds the value of the "description" field.
Description *string `json:"description,omitempty"`
selectValues sql.SelectValues
@@ -56,7 +58,7 @@ func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms:
values[i] = new([]byte)
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody:
case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody, errorpassthroughrule.FieldSkipMonitoring:
values[i] = new(sql.NullBool)
case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode:
values[i] = new(sql.NullInt64)
@@ -171,6 +173,12 @@ func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) err
_m.CustomMessage = new(string)
*_m.CustomMessage = value.String
}
case errorpassthroughrule.FieldSkipMonitoring:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field skip_monitoring", values[i])
} else if value.Valid {
_m.SkipMonitoring = value.Bool
}
case errorpassthroughrule.FieldDescription:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field description", values[i])
@@ -257,6 +265,9 @@ func (_m *ErrorPassthroughRule) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("skip_monitoring=")
builder.WriteString(fmt.Sprintf("%v", _m.SkipMonitoring))
builder.WriteString(", ")
if v := _m.Description; v != nil {
builder.WriteString("description=")
builder.WriteString(*v)

View File

@@ -39,6 +39,8 @@ const (
FieldPassthroughBody = "passthrough_body"
// FieldCustomMessage holds the string denoting the custom_message field in the database.
FieldCustomMessage = "custom_message"
// FieldSkipMonitoring holds the string denoting the skip_monitoring field in the database.
FieldSkipMonitoring = "skip_monitoring"
// FieldDescription holds the string denoting the description field in the database.
FieldDescription = "description"
// Table holds the table name of the errorpassthroughrule in the database.
@@ -61,6 +63,7 @@ var Columns = []string{
FieldResponseCode,
FieldPassthroughBody,
FieldCustomMessage,
FieldSkipMonitoring,
FieldDescription,
}
@@ -95,6 +98,8 @@ var (
DefaultPassthroughCode bool
// DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field.
DefaultPassthroughBody bool
// DefaultSkipMonitoring holds the default value on creation for the "skip_monitoring" field.
DefaultSkipMonitoring bool
)
// OrderOption defines the ordering options for the ErrorPassthroughRule queries.
@@ -155,6 +160,11 @@ func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCustomMessage, opts...).ToFunc()
}
// BySkipMonitoring orders the results by the skip_monitoring field.
func BySkipMonitoring(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSkipMonitoring, opts...).ToFunc()
}
// ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDescription, opts...).ToFunc()

View File

@@ -104,6 +104,11 @@ func CustomMessage(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v))
}
// SkipMonitoring applies equality check predicate on the "skip_monitoring" field. It's identical to SkipMonitoringEQ.
func SkipMonitoring(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
}
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))
@@ -544,6 +549,16 @@ func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v))
}
// SkipMonitoringEQ applies the EQ predicate on the "skip_monitoring" field.
func SkipMonitoringEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v))
}
// SkipMonitoringNEQ applies the NEQ predicate on the "skip_monitoring" field.
func SkipMonitoringNEQ(v bool) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldSkipMonitoring, v))
}
// DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.ErrorPassthroughRule {
return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v))

View File

@@ -172,6 +172,20 @@ func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *Error
return _c
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_c *ErrorPassthroughRuleCreate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleCreate {
_c.mutation.SetSkipMonitoring(v)
return _c
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_c *ErrorPassthroughRuleCreate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleCreate {
if v != nil {
_c.SetSkipMonitoring(*v)
}
return _c
}
// SetDescription sets the "description" field.
func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate {
_c.mutation.SetDescription(v)
@@ -249,6 +263,10 @@ func (_c *ErrorPassthroughRuleCreate) defaults() {
v := errorpassthroughrule.DefaultPassthroughBody
_c.mutation.SetPassthroughBody(v)
}
if _, ok := _c.mutation.SkipMonitoring(); !ok {
v := errorpassthroughrule.DefaultSkipMonitoring
_c.mutation.SetSkipMonitoring(v)
}
}
// check runs all checks and user-defined validators on the builder.
@@ -287,6 +305,9 @@ func (_c *ErrorPassthroughRuleCreate) check() error {
if _, ok := _c.mutation.PassthroughBody(); !ok {
return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)}
}
if _, ok := _c.mutation.SkipMonitoring(); !ok {
return &ValidationError{Name: "skip_monitoring", err: errors.New(`ent: missing required field "ErrorPassthroughRule.skip_monitoring"`)}
}
return nil
}
@@ -366,6 +387,10 @@ func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlg
_spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value)
_node.CustomMessage = &value
}
if value, ok := _c.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
_node.SkipMonitoring = value
}
if value, ok := _c.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
_node.Description = &value
@@ -608,6 +633,18 @@ func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleU
return u
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsert) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsert {
u.Set(errorpassthroughrule.FieldSkipMonitoring, v)
return u
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsert) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsert {
u.SetExcluded(errorpassthroughrule.FieldSkipMonitoring)
return u
}
// SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert {
u.Set(errorpassthroughrule.FieldDescription, v)
@@ -888,6 +925,20 @@ func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRu
})
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsertOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.SetSkipMonitoring(v)
})
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsertOne) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.UpdateSkipMonitoring()
})
}
// SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
@@ -1337,6 +1388,20 @@ func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughR
})
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (u *ErrorPassthroughRuleUpsertBulk) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.SetSkipMonitoring(v)
})
}
// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create.
func (u *ErrorPassthroughRuleUpsertBulk) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {
s.UpdateSkipMonitoring()
})
}
// SetDescription sets the "description" field.
func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk {
return u.Update(func(s *ErrorPassthroughRuleUpsert) {

View File

@@ -227,6 +227,20 @@ func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRule
return _u
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_u *ErrorPassthroughRuleUpdate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdate {
_u.mutation.SetSkipMonitoring(v)
return _u
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdate {
if v != nil {
_u.SetSkipMonitoring(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate {
_u.mutation.SetDescription(v)
@@ -387,6 +401,9 @@ func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, e
if _u.mutation.CustomMessageCleared() {
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
}
if value, ok := _u.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
}
@@ -611,6 +628,20 @@ func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughR
return _u
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetSkipMonitoring(v)
return _u
}
// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil.
func (_u *ErrorPassthroughRuleUpdateOne) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdateOne {
if v != nil {
_u.SetSkipMonitoring(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne {
_u.mutation.SetDescription(v)
@@ -801,6 +832,9 @@ func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *Er
if _u.mutation.CustomMessageCleared() {
_spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString)
}
if value, ok := _u.mutation.SkipMonitoring(); ok {
_spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value)
}

View File

@@ -66,6 +66,8 @@ type Group struct {
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
// 支持的模型系列claude, gemini_text, gemini_image
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// 分组显示排序,数值越小越靠前
SortOrder int `json:"sort_order,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -178,7 +180,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest:
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
values[i] = new(sql.NullString)
@@ -363,6 +365,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
return fmt.Errorf("unmarshal field supported_model_scopes: %w", err)
}
}
case group.FieldSortOrder:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field sort_order", values[i])
} else if value.Valid {
_m.SortOrder = int(value.Int64)
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -530,6 +538,9 @@ func (_m *Group) String() string {
builder.WriteString(", ")
builder.WriteString("supported_model_scopes=")
builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes))
builder.WriteString(", ")
builder.WriteString("sort_order=")
builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -63,6 +63,8 @@ const (
FieldMcpXMLInject = "mcp_xml_inject"
// FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database.
FieldSupportedModelScopes = "supported_model_scopes"
// FieldSortOrder holds the string denoting the sort_order field in the database.
FieldSortOrder = "sort_order"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -162,6 +164,7 @@ var Columns = []string{
FieldModelRoutingEnabled,
FieldMcpXMLInject,
FieldSupportedModelScopes,
FieldSortOrder,
}
var (
@@ -225,6 +228,8 @@ var (
DefaultMcpXMLInject bool
// DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field.
DefaultSupportedModelScopes []string
// DefaultSortOrder holds the default value on creation for the "sort_order" field.
DefaultSortOrder int
)
// OrderOption defines the ordering options for the Group queries.
@@ -345,6 +350,11 @@ func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc()
}
// BySortOrder orders the results by the sort_order field.
func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -165,6 +165,11 @@ func McpXMLInject(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v))
}
// SortOrder applies equality check predicate on the "sort_order" field. It's identical to SortOrderEQ.
func SortOrder(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1160,6 +1165,46 @@ func McpXMLInjectNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v))
}
// SortOrderEQ applies the EQ predicate on the "sort_order" field.
func SortOrderEQ(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSortOrder, v))
}
// SortOrderNEQ applies the NEQ predicate on the "sort_order" field.
func SortOrderNEQ(v int) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSortOrder, v))
}
// SortOrderIn applies the In predicate on the "sort_order" field.
func SortOrderIn(vs ...int) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSortOrder, vs...))
}
// SortOrderNotIn applies the NotIn predicate on the "sort_order" field.
func SortOrderNotIn(vs ...int) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSortOrder, vs...))
}
// SortOrderGT applies the GT predicate on the "sort_order" field.
func SortOrderGT(v int) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSortOrder, v))
}
// SortOrderGTE applies the GTE predicate on the "sort_order" field.
func SortOrderGTE(v int) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSortOrder, v))
}
// SortOrderLT applies the LT predicate on the "sort_order" field.
func SortOrderLT(v int) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSortOrder, v))
}
// SortOrderLTE applies the LTE predicate on the "sort_order" field.
func SortOrderLTE(v int) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSortOrder, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -340,6 +340,20 @@ func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate {
return _c
}
// SetSortOrder sets the "sort_order" field.
func (_c *GroupCreate) SetSortOrder(v int) *GroupCreate {
_c.mutation.SetSortOrder(v)
return _c
}
// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate {
if v != nil {
_c.SetSortOrder(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -521,6 +535,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultSupportedModelScopes
_c.mutation.SetSupportedModelScopes(v)
}
if _, ok := _c.mutation.SortOrder(); !ok {
v := group.DefaultSortOrder
_c.mutation.SetSortOrder(v)
}
return nil
}
@@ -585,6 +603,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.SupportedModelScopes(); !ok {
return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)}
}
if _, ok := _c.mutation.SortOrder(); !ok {
return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)}
}
return nil
}
@@ -708,6 +729,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value)
_node.SupportedModelScopes = value
}
if value, ok := _c.mutation.SortOrder(); ok {
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
_node.SortOrder = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1266,6 +1291,24 @@ func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert {
return u
}
// SetSortOrder sets the "sort_order" field.
func (u *GroupUpsert) SetSortOrder(v int) *GroupUpsert {
u.Set(group.FieldSortOrder, v)
return u
}
// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSortOrder() *GroupUpsert {
u.SetExcluded(group.FieldSortOrder)
return u
}
// AddSortOrder adds v to the "sort_order" field.
func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert {
u.Add(group.FieldSortOrder, v)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1780,6 +1823,27 @@ func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne {
})
}
// SetSortOrder sets the "sort_order" field.
func (u *GroupUpsertOne) SetSortOrder(v int) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSortOrder(v)
})
}
// AddSortOrder adds v to the "sort_order" field.
func (u *GroupUpsertOne) AddSortOrder(v int) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSortOrder(v)
})
}
// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSortOrder()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2460,6 +2524,27 @@ func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk {
})
}
// SetSortOrder sets the "sort_order" field.
func (u *GroupUpsertBulk) SetSortOrder(v int) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSortOrder(v)
})
}
// AddSortOrder adds v to the "sort_order" field.
func (u *GroupUpsertBulk) AddSortOrder(v int) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSortOrder(v)
})
}
// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSortOrder()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -475,6 +475,27 @@ func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate {
return _u
}
// SetSortOrder sets the "sort_order" field.
func (_u *GroupUpdate) SetSortOrder(v int) *GroupUpdate {
_u.mutation.ResetSortOrder()
_u.mutation.SetSortOrder(v)
return _u
}
// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSortOrder(v *int) *GroupUpdate {
if v != nil {
_u.SetSortOrder(*v)
}
return _u
}
// AddSortOrder adds value to the "sort_order" field.
func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate {
_u.mutation.AddSortOrder(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -912,6 +933,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
sqljson.Append(u, group.FieldSupportedModelScopes, value)
})
}
if value, ok := _u.mutation.SortOrder(); ok {
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedSortOrder(); ok {
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1666,6 +1693,27 @@ func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne
return _u
}
// SetSortOrder sets the "sort_order" field.
func (_u *GroupUpdateOne) SetSortOrder(v int) *GroupUpdateOne {
_u.mutation.ResetSortOrder()
_u.mutation.SetSortOrder(v)
return _u
}
// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSortOrder(v *int) *GroupUpdateOne {
if v != nil {
_u.SetSortOrder(*v)
}
return _u
}
// AddSortOrder adds value to the "sort_order" field.
func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne {
_u.mutation.AddSortOrder(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2133,6 +2181,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
sqljson.Append(u, group.FieldSupportedModelScopes, value)
})
}
if value, ok := _u.mutation.SortOrder(); ok {
_spec.SetField(group.FieldSortOrder, field.TypeInt, value)
}
if value, ok := _u.mutation.AddedSortOrder(); ok {
_spec.AddField(group.FieldSortOrder, field.TypeInt, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -325,6 +325,7 @@ var (
{Name: "response_code", Type: field.TypeInt, Nullable: true},
{Name: "passthrough_body", Type: field.TypeBool, Default: true},
{Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647},
{Name: "skip_monitoring", Type: field.TypeBool, Default: false},
{Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647},
}
// ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table.
@@ -372,6 +373,7 @@ var (
{Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
{Name: "mcp_xml_inject", Type: field.TypeBool, Default: true},
{Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "sort_order", Type: field.TypeInt, Default: 0},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -404,6 +406,11 @@ var (
Unique: false,
Columns: []*schema.Column{GroupsColumns[3]},
},
{
Name: "group_sort_order",
Unique: false,
Columns: []*schema.Column{GroupsColumns[25]},
},
},
}
// PromoCodesColumns holds the columns for the "promo_codes" table.

View File

@@ -5776,6 +5776,7 @@ type ErrorPassthroughRuleMutation struct {
addresponse_code *int
passthrough_body *bool
custom_message *string
skip_monitoring *bool
description *string
clearedFields map[string]struct{}
done bool
@@ -6503,6 +6504,42 @@ func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() {
delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage)
}
// SetSkipMonitoring sets the "skip_monitoring" field.
func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) {
m.skip_monitoring = &b
}
// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation.
func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) {
v := m.skip_monitoring
if v == nil {
return
}
return *v, true
}
// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity.
// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSkipMonitoring requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err)
}
return oldValue.SkipMonitoring, nil
}
// ResetSkipMonitoring resets all changes to the "skip_monitoring" field.
func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() {
m.skip_monitoring = nil
}
// SetDescription sets the "description" field.
func (m *ErrorPassthroughRuleMutation) SetDescription(s string) {
m.description = &s
@@ -6586,7 +6623,7 @@ func (m *ErrorPassthroughRuleMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ErrorPassthroughRuleMutation) Fields() []string {
fields := make([]string, 0, 14)
fields := make([]string, 0, 15)
if m.created_at != nil {
fields = append(fields, errorpassthroughrule.FieldCreatedAt)
}
@@ -6626,6 +6663,9 @@ func (m *ErrorPassthroughRuleMutation) Fields() []string {
if m.custom_message != nil {
fields = append(fields, errorpassthroughrule.FieldCustomMessage)
}
if m.skip_monitoring != nil {
fields = append(fields, errorpassthroughrule.FieldSkipMonitoring)
}
if m.description != nil {
fields = append(fields, errorpassthroughrule.FieldDescription)
}
@@ -6663,6 +6703,8 @@ func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) {
return m.PassthroughBody()
case errorpassthroughrule.FieldCustomMessage:
return m.CustomMessage()
case errorpassthroughrule.FieldSkipMonitoring:
return m.SkipMonitoring()
case errorpassthroughrule.FieldDescription:
return m.Description()
}
@@ -6700,6 +6742,8 @@ func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string
return m.OldPassthroughBody(ctx)
case errorpassthroughrule.FieldCustomMessage:
return m.OldCustomMessage(ctx)
case errorpassthroughrule.FieldSkipMonitoring:
return m.OldSkipMonitoring(ctx)
case errorpassthroughrule.FieldDescription:
return m.OldDescription(ctx)
}
@@ -6802,6 +6846,13 @@ func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) er
}
m.SetCustomMessage(v)
return nil
case errorpassthroughrule.FieldSkipMonitoring:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSkipMonitoring(v)
return nil
case errorpassthroughrule.FieldDescription:
v, ok := value.(string)
if !ok {
@@ -6963,6 +7014,9 @@ func (m *ErrorPassthroughRuleMutation) ResetField(name string) error {
case errorpassthroughrule.FieldCustomMessage:
m.ResetCustomMessage()
return nil
case errorpassthroughrule.FieldSkipMonitoring:
m.ResetSkipMonitoring()
return nil
case errorpassthroughrule.FieldDescription:
m.ResetDescription()
return nil
@@ -7059,6 +7113,8 @@ type GroupMutation struct {
mcp_xml_inject *bool
supported_model_scopes *[]string
appendsupported_model_scopes []string
sort_order *int
addsort_order *int
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -8411,6 +8467,62 @@ func (m *GroupMutation) ResetSupportedModelScopes() {
m.appendsupported_model_scopes = nil
}
// SetSortOrder sets the "sort_order" field.
func (m *GroupMutation) SetSortOrder(i int) {
m.sort_order = &i
m.addsort_order = nil
}
// SortOrder returns the value of the "sort_order" field in the mutation.
func (m *GroupMutation) SortOrder() (r int, exists bool) {
v := m.sort_order
if v == nil {
return
}
return *v, true
}
// OldSortOrder returns the old "sort_order" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSortOrder requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
}
return oldValue.SortOrder, nil
}
// AddSortOrder adds i to the "sort_order" field.
func (m *GroupMutation) AddSortOrder(i int) {
if m.addsort_order != nil {
*m.addsort_order += i
} else {
m.addsort_order = &i
}
}
// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
func (m *GroupMutation) AddedSortOrder() (r int, exists bool) {
v := m.addsort_order
if v == nil {
return
}
return *v, true
}
// ResetSortOrder resets all changes to the "sort_order" field.
func (m *GroupMutation) ResetSortOrder() {
m.sort_order = nil
m.addsort_order = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -8769,7 +8881,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 24)
fields := make([]string, 0, 25)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -8842,6 +8954,9 @@ func (m *GroupMutation) Fields() []string {
if m.supported_model_scopes != nil {
fields = append(fields, group.FieldSupportedModelScopes)
}
if m.sort_order != nil {
fields = append(fields, group.FieldSortOrder)
}
return fields
}
@@ -8898,6 +9013,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.McpXMLInject()
case group.FieldSupportedModelScopes:
return m.SupportedModelScopes()
case group.FieldSortOrder:
return m.SortOrder()
}
return nil, false
}
@@ -8955,6 +9072,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldMcpXMLInject(ctx)
case group.FieldSupportedModelScopes:
return m.OldSupportedModelScopes(ctx)
case group.FieldSortOrder:
return m.OldSortOrder(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -9132,6 +9251,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetSupportedModelScopes(v)
return nil
case group.FieldSortOrder:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSortOrder(v)
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -9170,6 +9296,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addfallback_group_id_on_invalid_request != nil {
fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
}
if m.addsort_order != nil {
fields = append(fields, group.FieldSortOrder)
}
return fields
}
@@ -9198,6 +9327,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedFallbackGroupID()
case group.FieldFallbackGroupIDOnInvalidRequest:
return m.AddedFallbackGroupIDOnInvalidRequest()
case group.FieldSortOrder:
return m.AddedSortOrder()
}
return nil, false
}
@@ -9277,6 +9408,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddFallbackGroupIDOnInvalidRequest(v)
return nil
case group.FieldSortOrder:
v, ok := value.(int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddSortOrder(v)
return nil
}
return fmt.Errorf("unknown Group numeric field %s", name)
}
@@ -9445,6 +9583,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldSupportedModelScopes:
m.ResetSupportedModelScopes()
return nil
case group.FieldSortOrder:
m.ResetSortOrder()
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}

View File

@@ -326,6 +326,10 @@ func init() {
errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor()
// errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field.
errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool)
// errorpassthroughruleDescSkipMonitoring is the schema descriptor for skip_monitoring field.
errorpassthroughruleDescSkipMonitoring := errorpassthroughruleFields[11].Descriptor()
// errorpassthroughrule.DefaultSkipMonitoring holds the default value on creation for the skip_monitoring field.
errorpassthroughrule.DefaultSkipMonitoring = errorpassthroughruleDescSkipMonitoring.Default.(bool)
groupMixin := schema.Group{}.Mixin()
groupMixinHooks1 := groupMixin[1].Hooks()
group.Hooks[0] = groupMixinHooks1[0]
@@ -409,6 +413,10 @@ func init() {
groupDescSupportedModelScopes := groupFields[20].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field.
groupDescSortOrder := groupFields[21].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.

View File

@@ -105,6 +105,12 @@ func (ErrorPassthroughRule) Fields() []ent.Field {
Optional().
Nillable(),
// skip_monitoring: 是否跳过运维监控记录
// true: 匹配此规则的错误不会被记录到 ops_error_logs
// false: 正常记录到运维监控(默认行为)
field.Bool("skip_monitoring").
Default(false),
// description: 规则描述,用于说明规则的用途
field.Text("description").
Optional().

View File

@@ -121,6 +121,11 @@ func (Group) Fields() []ent.Field {
Default([]string{"claude", "gemini_text", "gemini_image"}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("支持的模型系列claude, gemini_text, gemini_image"),
// 分组排序 (added by migration 052)
field.Int("sort_order").
Default(0).
Comment("分组显示排序,数值越小越靠前"),
}
}
@@ -149,5 +154,6 @@ func (Group) Indexes() []ent.Index {
index.Fields("subscription_type"),
index.Fields("is_exclusive"),
index.Fields("deleted_at"),
index.Fields("sort_order"),
}
}

View File

@@ -103,6 +103,7 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect

View File

@@ -135,6 +135,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -170,6 +172,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -203,10 +207,14 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@@ -230,6 +238,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -252,6 +262,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@@ -424,10 +424,17 @@ type TestAccountRequest struct {
}
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
SelectedAccountIDs []string `json:"selected_account_ids"`
}
type PreviewFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
// Test handles testing account connectivity with SSE streaming
@@ -466,10 +473,11 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
SelectedAccountIDs: req.SelectedAccountIDs,
})
if err != nil {
// Provide detailed error message for CRS sync failures
@@ -480,6 +488,28 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
response.Success(c, result)
}
// PreviewFromCRS handles previewing accounts from CRS before sync
// POST /api/v1/admin/accounts/sync/crs/preview
func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
var req PreviewFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.crsSyncService.PreviewFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
})
if err != nil {
response.InternalError(c, "CRS preview failed: "+err.Error())
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {

View File

@@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc)
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc)

View File

@@ -357,5 +357,9 @@ func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int
return s.redeems, int64(len(s.redeems)), 100.0, nil
}
func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
return nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -32,6 +32,7 @@ type CreateErrorPassthroughRuleRequest struct {
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"`
}
@@ -48,6 +49,7 @@ type UpdateErrorPassthroughRuleRequest struct {
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"`
}
@@ -122,6 +124,9 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
} else {
rule.PassthroughBody = true
}
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
rule.ResponseCode = req.ResponseCode
rule.CustomMessage = req.CustomMessage
rule.Description = req.Description
@@ -190,6 +195,7 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
ResponseCode: existing.ResponseCode,
PassthroughBody: existing.PassthroughBody,
CustomMessage: existing.CustomMessage,
SkipMonitoring: existing.SkipMonitoring,
Description: existing.Description,
}
@@ -230,6 +236,9 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
if req.Description != nil {
rule.Description = req.Description
}
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {

View File

@@ -302,3 +302,36 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
}
response.Paginated(c, outKeys, total, page, pageSize)
}
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {
ID int64 `json:"id" binding:"required"`
SortOrder int `json:"sort_order"`
} `json:"updates" binding:"required,min=1"`
}
// UpdateSortOrder handles updating group sort orders
// PUT /api/v1/admin/groups/sort-order
func (h *GroupHandler) UpdateSortOrder(c *gin.Context) {
var req UpdateSortOrderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
updates := make([]service.GroupSortOrderUpdate, 0, len(req.Updates))
for _, u := range req.Updates {
updates = append(updates, service.GroupSortOrderUpdate{
ID: u.ID,
SortOrder: u.SortOrder,
})
}
if err := h.adminService.UpdateGroupSortOrders(c.Request.Context(), updates); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Sort order updated successfully"})
}

View File

@@ -11,15 +11,23 @@ import (
"github.com/gin-gonic/gin"
)
// UserWithConcurrency wraps AdminUser with current concurrency info
type UserWithConcurrency struct {
dto.AdminUser
CurrentConcurrency int `json:"current_concurrency"`
}
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
adminService service.AdminService
concurrencyService *service.ConcurrencyService
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService) *UserHandler {
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
return &UserHandler{
adminService: adminService,
adminService: adminService,
concurrencyService: concurrencyService,
}
}
@@ -87,10 +95,30 @@ func (h *UserHandler) List(c *gin.Context) {
return
}
out := make([]dto.AdminUser, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
// Batch get current concurrency (nil map if unavailable)
var loadInfo map[int64]*service.UserLoadInfo
if len(users) > 0 && h.concurrencyService != nil {
usersConcurrency := make([]service.UserWithConcurrency, len(users))
for i := range users {
usersConcurrency[i] = service.UserWithConcurrency{
ID: users[i].ID,
MaxConcurrency: users[i].Concurrency,
}
}
loadInfo, _ = h.concurrencyService.GetUsersLoadBatch(c.Request.Context(), usersConcurrency)
}
// Build response with concurrency info
out := make([]UserWithConcurrency, len(users))
for i := range users {
out[i] = UserWithConcurrency{
AdminUser: *dto.UserFromServiceAdmin(&users[i]),
}
if info := loadInfo[users[i].ID]; info != nil {
out[i].CurrentConcurrency = info.CurrentConcurrency
}
}
response.Paginated(c, out, total, page, pageSize)
}

View File

@@ -115,6 +115,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
MCPXMLInject: g.MCPXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))

View File

@@ -2,11 +2,6 @@ package dto
import "time"
type ScopeRateLimitInfo struct {
ResetAt time.Time `json:"reset_at"`
RemainingSec int64 `json:"remaining_sec"`
}
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
@@ -98,6 +93,9 @@ type AdminGroup struct {
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
// 分组排序
SortOrder int `json:"sort_order"`
}
type Account struct {
@@ -126,9 +124,6 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
// Antigravity scope 级限流状态(从 extra 提取)
ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
@@ -114,7 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
@@ -203,6 +204,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台
@@ -229,9 +235,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
@@ -239,6 +253,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
@@ -333,17 +360,39 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
@@ -385,10 +434,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
fallbackUsed := false
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
}
for {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
@@ -401,6 +458,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
@@ -482,7 +552,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
@@ -528,17 +598,39 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
sameAccountRetryCount[account.ID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
if !sleepSameAccountRetryDelay(c.Request.Context()) {
return
}
continue
}
// 同账号重试用尽,执行临时封禁并切换账号
if failoverErr.RetryableOnSameAccount {
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
}
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
@@ -801,6 +893,65 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
const (
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
maxSameAccountRetries = 2
// sameAccountRetryDelay 同账号重试间隔
sameAccountRetryDelay = 500 * time.Millisecond
)
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
func sleepSameAccountRetryDelay(ctx context.Context) bool {
select {
case <-ctx.Done():
return false
case <-time.After(sameAccountRetryDelay):
return true
}
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
delay := time.Duration(switchCount-1) * time.Second
if delay <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
// 当分组内只有一个可用账号且上游返回 503MODEL_CAPACITY_EXHAUSTED时使用
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
// (最多 3 次、总等待 30s所以 Handler 层的退避只需短暂等待即可。
// 返回 false 表示 context 已取消。
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
// 固定短延时2s
// Service 层已经在原地等待了足够长的时间retryDelay × 重试次数),
// Handler 层只需短暂间隔后重新进入 Service 层即可。
const delay = 2 * time.Second
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
@@ -820,6 +971,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
msg = *rule.CustomMessage
}
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}
@@ -934,7 +1089,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
@@ -962,6 +1117,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号

View File

@@ -0,0 +1,51 @@
package handler
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// sleepAntigravitySingleAccountBackoff 测试
// ---------------------------------------------------------------------------
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.True(t, ok, "should return true when context is not canceled")
// 固定延迟 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
}
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
elapsed := time.Since(start)
require.False(t, ok, "should return false when context is canceled")
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
}
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
// 验证不同 retryCount 都使用固定 2s 延迟
ctx := context.Background()
start := time.Now()
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
elapsed := time.Since(start)
require.True(t, ok)
// 即使 retryCount=5延迟仍然是固定的 2s
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
require.Less(t, elapsed, 5*time.Second)
}

View File

@@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
@@ -30,13 +31,6 @@ import (
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
return true
}
return geminiCLITmpDirRegex.Match(body)
}
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
sessionHash := extractGeminiCLISessionHash(c, body)
if sessionHash == "" {
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
parsedReq, _ := service.ParseGatewayRequest(body)
parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini)
if parsedReq != nil {
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
}
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
}
sessionKey := sessionHash
@@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var geminiDigestChain string
var geminiPrefixHash string
var geminiSessionUUID string
var matchedDigestChain string
useDigestFallback := sessionBoundAccountID == 0
if useDigestFallback {
@@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
)
// 查找会话
foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession(
foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession(
c.Request.Context(),
derefGroupID(apiKey.GroupID),
geminiPrefixHash,
geminiDigestChain,
)
if found {
matchedDigestChain = foundMatchedChain
sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
@@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
isCLI := isGeminiCLIRequest(c, body)
cleanedForUnknownBinding := false
maxAccountSwitches := h.maxAccountSwitchesGemini
@@ -325,6 +327,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
@@ -332,6 +341,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
failedAccountIDs = make(map[int64]struct{})
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
c.Request = c.Request.WithContext(ctx)
continue
}
}
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
@@ -344,10 +366,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
@@ -410,7 +432,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
} else {
result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
@@ -422,7 +444,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -433,6 +455,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverErr = failoverErr
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// ForwardNative already wrote the response
@@ -453,6 +480,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
geminiDigestChain,
geminiSessionUUID,
account.ID,
matchedDigestChain,
); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err)
}
@@ -526,6 +554,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
msg = *rule.CustomMessage
}
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
googleError(c, respCode, msg)
return
}

View File

@@ -354,6 +354,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
msg = *rule.CustomMessage
}
if rule.SkipMonitoring {
c.Set(service.OpsSkipPassthroughKey, true)
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}

View File

@@ -537,6 +537,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
if skip, _ := v.(bool); skip {
return
}
}
enqueueOpsErrorLog(ops, entry, requestBody)
return
}
@@ -544,6 +551,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
body := w.buf.Bytes()
parsed := parseOpsErrorResponse(body)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
if skip, _ := v.(bool); skip {
return
}
}
// Skip logging if the error should be filtered based on settings
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
return

View File

@@ -18,6 +18,7 @@ type ErrorPassthroughRule struct {
ResponseCode *int `json:"response_code"` // 自定义状态码passthrough_code=false 时使用)
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
CustomMessage *string `json:"custom_message"` // 自定义错误信息passthrough_body=false 时使用)
SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录
Description *string `json:"description"` // 规则描述
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`

View File

@@ -115,6 +115,23 @@ type LoadCodeAssistResponse struct {
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
}
// OnboardUserRequest onboardUser 请求
type OnboardUserRequest struct {
TierID string `json:"tierId"`
Metadata struct {
IDEType string `json:"ideType"`
Platform string `json:"platform,omitempty"`
PluginType string `json:"pluginType,omitempty"`
} `json:"metadata"`
}
// OnboardUserResponse onboardUser 响应
type OnboardUserResponse struct {
Name string `json:"name,omitempty"`
Done bool `json:"done"`
Response map[string]any `json:"response,omitempty"`
}
// GetTier 获取账户类型
// 优先返回 paidTier付费订阅级别否则返回 currentTier
func (r *LoadCodeAssistResponse) GetTier() string {
@@ -361,6 +378,117 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, lastErr
}
// OnboardUser 触发账号 onboarding并返回 project_id
// 说明:
// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject
// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。
func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
tierID = strings.TrimSpace(tierID)
if tierID == "" {
return "", fmt.Errorf("tier_id 为空")
}
reqBody := OnboardUserRequest{TierID: tierID}
reqBody.Metadata.IDEType = "ANTIGRAVITY"
reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED"
reqBody.Metadata.PluginType = "GEMINI"
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
availableURLs := BaseURLs
var lastErr error
for urlIdx, baseURL := range availableURLs {
apiURL := baseURL + "/v1internal:onboardUser"
for attempt := 1; attempt <= 5; attempt++ {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes))
if err != nil {
lastErr = fmt.Errorf("创建请求失败: %w", err)
break
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("onboardUser 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
break
}
return "", lastErr
}
respBodyBytes, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
}
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
break
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
return "", lastErr
}
var onboardResp OnboardUserResponse
if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil {
lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err)
return "", lastErr
}
if onboardResp.Done {
if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" {
DefaultURLAvailability.MarkSuccess(baseURL)
return projectID, nil
}
lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id")
return "", lastErr
}
// done=false 时等待后重试(与 CLIProxyAPI 行为一致)
select {
case <-time.After(2 * time.Second):
case <-ctx.Done():
return "", ctx.Err()
}
}
}
if lastErr != nil {
return "", lastErr
}
return "", fmt.Errorf("onboardUser 未返回 project_id")
}
func extractProjectIDFromOnboardResponse(resp map[string]any) string {
if len(resp) == 0 {
return ""
}
if v, ok := resp["cloudaicompanionProject"]; ok {
switch project := v.(type) {
case string:
return strings.TrimSpace(project)
case map[string]any:
if id, ok := project["id"].(string); ok {
return strings.TrimSpace(id)
}
}
}
return ""
}
// ModelQuotaInfo 模型配额信息
type ModelQuotaInfo struct {
RemainingFraction float64 `json:"remainingFraction"`

View File

@@ -0,0 +1,76 @@
package antigravity
import (
"testing"
)
func TestExtractProjectIDFromOnboardResponse(t *testing.T) {
t.Parallel()
tests := []struct {
name string
resp map[string]any
want string
}{
{
name: "nil response",
resp: nil,
want: "",
},
{
name: "empty response",
resp: map[string]any{},
want: "",
},
{
name: "project as string",
resp: map[string]any{
"cloudaicompanionProject": "my-project-123",
},
want: "my-project-123",
},
{
name: "project as string with spaces",
resp: map[string]any{
"cloudaicompanionProject": " my-project-123 ",
},
want: "my-project-123",
},
{
name: "project as map with id",
resp: map[string]any{
"cloudaicompanionProject": map[string]any{
"id": "proj-from-map",
},
},
want: "proj-from-map",
},
{
name: "project as map without id",
resp: map[string]any{
"cloudaicompanionProject": map[string]any{
"name": "some-name",
},
},
want: "",
},
{
name: "missing cloudaicompanionProject key",
resp: map[string]any{
"otherField": "value",
},
want: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := extractProjectIDFromOnboardResponse(tc.resp)
if got != tc.want {
t.Fatalf("extractProjectIDFromOnboardResponse() = %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -271,6 +271,21 @@ func filterOpenCodePrompt(text string) string {
return ""
}
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
func filterSystemBlockByPrefix(text string) string {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return ""
}
}
return text
}
// buildSystemInstruction 构建 systemInstruction与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart
@@ -287,8 +302,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(sysStr)
// 过滤 OpenCode 默认提示词和黑名单前缀
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
@@ -302,8 +317,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(block.Text)
// 过滤 OpenCode 默认提示词和黑名单前缀
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}

View File

@@ -28,4 +28,8 @@ const (
// IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求
// 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent
IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku"
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
// 在此模式下Service 层的模型限流预检查将等待限流过期而非直接切换账号。
SingleAccountRetry Key = "ctx_single_account_retry"
)

View File

@@ -282,6 +282,34 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return &accounts[0], nil
}
func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT id, extra->>'crs_account_id'
FROM accounts
WHERE deleted_at IS NULL
AND extra->>'crs_account_id' IS NOT NULL
AND extra->>'crs_account_id' != ''
`)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make(map[string]int64)
for rows.Next() {
var id int64
var crsID string
if err := rows.Scan(&id, &crsID); err != nil {
return nil, err
}
result[crsID] = id
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
if account == nil {
return nil
@@ -798,53 +826,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return nil
}
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
scopeKey := string(scope)
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
`UPDATE accounts SET
extra = jsonb_set(
jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true),
ARRAY['antigravity_quota_scopes', $1]::text[],
$2::jsonb,
true
),
updated_at = NOW(),
last_used_at = NOW()
WHERE id = $3 AND deleted_at IS NULL`,
scopeKey,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
if scope == "" {
return nil

View File

@@ -485,6 +485,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.McpXMLInject,
SupportedModelScopes: g.SupportedModelScopes,
SortOrder: g.SortOrder,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -54,7 +54,8 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
SetPassthroughBody(rule.PassthroughBody).
SetSkipMonitoring(rule.SkipMonitoring)
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
@@ -90,7 +91,8 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
SetPassthroughBody(rule.PassthroughBody).
SetSkipMonitoring(rule.SkipMonitoring)
// 处理可选字段
if len(rule.ErrorCodes) > 0 {
@@ -149,6 +151,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model
Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody,
SkipMonitoring: e.SkipMonitoring,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}

View File

@@ -11,63 +11,6 @@ import (
const stickySessionPrefix = "sticky_session:"
// Gemini Trie Lua 脚本
const (
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
// ARGV[2] = TTL seconds (用于刷新)
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
// 查找成功时自动刷新 TTL防止活跃会话意外过期
geminiTrieFindScript = `
local chain = ARGV[1]
local ttl = tonumber(ARGV[2])
local lastMatch = nil
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
local val = redis.call('HGET', KEYS[1], path)
if val and val ~= "" then
lastMatch = val
end
end
if lastMatch then
redis.call('EXPIRE', KEYS[1], ttl)
end
return lastMatch
`
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
// KEYS[1] = trie key
// ARGV[1] = digestChain
// ARGV[2] = value (uuid:accountID)
// ARGV[3] = TTL seconds
geminiTrieSaveScript = `
local chain = ARGV[1]
local value = ARGV[2]
local ttl = tonumber(ARGV[3])
local path = ""
for part in string.gmatch(chain, "[^-]+") do
path = path == "" and part or path .. "-" .. part
end
redis.call('HSET', KEYS[1], path, value)
redis.call('EXPIRE', KEYS[1], ttl)
return "OK"
`
)
// 模型负载统计相关常量
const (
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
modelLoadTTL = 24 * time.Hour // 调用次数 TTL24 小时无调用后清零)
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
)
type gatewayCache struct {
rdb *redis.Client
}
@@ -108,133 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}
// ============ Antigravity 模型负载统计方法 ============
// modelLoadKey 构建模型调用次数 key
// 格式: ag:model_load:{accountID}:{model}
func modelLoadKey(accountID int64, model string) string {
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
}
// modelLastUsedKey 构建模型最后调度时间 key
// 格式: ag:model_last_used:{accountID}:{model}
func modelLastUsedKey(accountID int64, model string) string {
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
}
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
// 返回更新后的调用次数
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
loadKey := modelLoadKey(accountID, model)
lastUsedKey := modelLastUsedKey(accountID, model)
pipe := c.rdb.Pipeline()
incrCmd := pipe.Incr(ctx, loadKey)
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, err
}
return incrCmd.Val(), nil
}
// GetModelLoadBatch 批量获取账号的模型负载信息
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
if len(accountIDs) == 0 {
return make(map[int64]*service.ModelLoadInfo), nil
}
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
}
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
func (c *gatewayCache) pipelineModelLoadGet(
ctx context.Context,
accountIDs []int64,
model string,
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
pipe := c.rdb.Pipeline()
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
for _, id := range accountIDs {
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
}
_, _ = pipe.Exec(ctx) // 忽略错误key 不存在是正常的
return loadCmds, lastUsedCmds
}
// parseModelLoadResults 解析 Pipeline 结果
func (c *gatewayCache) parseModelLoadResults(
accountIDs []int64,
loadCmds map[int64]*redis.StringCmd,
lastUsedCmds map[int64]*redis.StringCmd,
) map[int64]*service.ModelLoadInfo {
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
for _, id := range accountIDs {
result[id] = &service.ModelLoadInfo{
CallCount: getInt64OrZero(loadCmds[id]),
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
}
}
return result
}
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
func getInt64OrZero(cmd *redis.StringCmd) int64 {
val, _ := cmd.Int64()
return val
}
// getTimeOrZero 从 StringCmd 获取 time.Time失败返回零值
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
val, err := cmd.Int64()
if err != nil {
return time.Time{}
}
return time.Unix(val, 0)
}
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" {
return "", 0, false
}
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
// 使用 Lua 脚本在 Redis 端执行 Trie 查找O(L) 次 HGET1 次网络往返
// 查找成功时自动刷新 TTL防止活跃会话意外过期
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
if err != nil || result == nil {
return "", 0, false
}
value, ok := result.(string)
if !ok || value == "" {
return "", 0, false
}
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
return uuid, accountID, ok
}
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" {
return nil
}
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
value := service.FormatGeminiSessionValue(uuid, accountID)
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
}

View File

@@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
// ============ Gemini Trie 会话测试 ============
func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() {
groupID := int64(1)
prefixHash := "testprefix"
digestChain := "u:hash1-m:hash2-u:hash3"
uuid := "test-uuid-123"
accountID := int64(42)
// 保存会话
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID)
require.NoError(s.T(), err, "SaveGeminiSession")
// 精确匹配查找
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain)
require.True(s.T(), found, "should find exact match")
require.Equal(s.T(), uuid, foundUUID)
require.Equal(s.T(), accountID, foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() {
groupID := int64(1)
prefixHash := "prefixmatch"
shortChain := "u:a-m:b"
longChain := "u:a-m:b-u:c-m:d"
uuid := "uuid-prefix"
accountID := int64(100)
// 保存短链
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID)
require.NoError(s.T(), err)
// 用长链查找,应该匹配到短链(前缀匹配)
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain)
require.True(s.T(), found, "should find prefix match")
require.Equal(s.T(), uuid, foundUUID)
require.Equal(s.T(), accountID, foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() {
groupID := int64(1)
prefixHash := "longestmatch"
// 保存多个不同长度的链
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1)
require.NoError(s.T(), err)
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2)
require.NoError(s.T(), err)
err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3)
require.NoError(s.T(), err)
// 查找更长的链,应该匹配到最长的前缀
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e")
require.True(s.T(), found, "should find longest prefix match")
require.Equal(s.T(), "uuid-long", foundUUID)
require.Equal(s.T(), int64(3), foundAccountID)
// 查找中等长度的链
foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x")
require.True(s.T(), found)
require.Equal(s.T(), "uuid-medium", foundUUID)
require.Equal(s.T(), int64(2), foundAccountID)
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() {
groupID := int64(1)
prefixHash := "nomatch"
digestChain := "u:a-m:b"
// 保存一个会话
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1)
require.NoError(s.T(), err)
// 用不同的链查找,应该找不到
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y")
require.False(s.T(), found, "should not find non-matching chain")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() {
groupID := int64(1)
digestChain := "u:a-m:b"
// 保存到 prefixHash1
err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1)
require.NoError(s.T(), err)
// 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离)
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain)
require.False(s.T(), found, "different prefixHash should be isolated")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() {
prefixHash := "sameprefix"
digestChain := "u:a-m:b"
// 保存到 groupID 1
err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1)
require.NoError(s.T(), err)
// 用 groupID 2 查找,应该找不到(分组隔离)
_, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain)
require.False(s.T(), found, "different groupID should be isolated")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() {
groupID := int64(1)
prefixHash := "emptytest"
// 空链不应该保存
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1)
require.NoError(s.T(), err, "empty chain should not error")
// 空链查找应该返回 false
_, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "")
require.False(s.T(), found, "empty chain should not match")
}
func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() {
groupID := int64(1)
prefixHash := "multisession"
// 保存多个不同会话(模拟 1000 个并发会话的场景)
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID)
require.NoError(s.T(), err)
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain)
require.True(s.T(), found, "should find session: %s", sess.chain)
require.Equal(s.T(), sess.uuid, foundUUID)
require.Equal(s.T(), sess.accountID, foundAccountID)
}
// 验证继续对话的场景
foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg")
require.True(s.T(), found)
require.Equal(s.T(), "uuid-2", foundUUID)
require.Equal(s.T(), int64(2), foundAccountID)
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))

View File

@@ -1,234 +0,0 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// ============ Gateway Cache 模型负载统计集成测试 ============
type GatewayCacheModelLoadSuite struct {
suite.Suite
}
func TestGatewayCacheModelLoadSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheModelLoadSuite))
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(123)
model := "claude-sonnet-4-20250514"
// 首次调用应返回 1
count1, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
// 第二次调用应返回 2
count2, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(2), count2)
// 第三次调用应返回 3
count3, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
require.Equal(t, int64(3), count3)
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(456)
model1 := "claude-sonnet-4-20250514"
model2 := "claude-opus-4-5-20251101"
// 不同模型应该独立计数
count1, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
count2, err := cache.IncrModelCallCount(ctx, accountID, model2)
require.NoError(t, err)
require.Equal(t, int64(1), count2)
count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
require.Equal(t, int64(2), count1Again)
}
func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
account1 := int64(111)
account2 := int64(222)
model := "gemini-2.5-pro"
// 不同账号应该独立计数
count1, err := cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
require.Equal(t, int64(1), count1)
count2, err := cache.IncrModelCallCount(ctx, account2, model)
require.NoError(t, err)
require.Equal(t, int64(1), count2)
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model")
require.NoError(t, err)
require.NotNil(t, result)
require.Empty(t, result)
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
// 查询不存在的账号应返回零值
result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514")
require.NoError(t, err)
require.Len(t, result, 2)
require.Equal(t, int64(0), result[9999].CallCount)
require.True(t, result[9999].LastUsedAt.IsZero())
require.Equal(t, int64(0), result[9998].CallCount)
require.True(t, result[9998].LastUsedAt.IsZero())
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(789)
model := "claude-sonnet-4-20250514"
// 先增加调用次数
beforeIncr := time.Now()
_, err := cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, accountID, model)
require.NoError(t, err)
afterIncr := time.Now()
// 获取负载信息
result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model)
require.NoError(t, err)
require.Len(t, result, 1)
loadInfo := result[accountID]
require.NotNil(t, loadInfo)
require.Equal(t, int64(3), loadInfo.CallCount)
require.False(t, loadInfo.LastUsedAt.IsZero())
// LastUsedAt 应该在 beforeIncr 和 afterIncr 之间
require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr))
require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr))
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
model := "claude-opus-4-5-20251101"
account1 := int64(1001)
account2 := int64(1002)
account3 := int64(1003) // 不调用
// account1 调用 2 次
_, err := cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
_, err = cache.IncrModelCallCount(ctx, account1, model)
require.NoError(t, err)
// account2 调用 5 次
for i := 0; i < 5; i++ {
_, err = cache.IncrModelCallCount(ctx, account2, model)
require.NoError(t, err)
}
// 批量获取
result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model)
require.NoError(t, err)
require.Len(t, result, 3)
require.Equal(t, int64(2), result[account1].CallCount)
require.False(t, result[account1].LastUsedAt.IsZero())
require.Equal(t, int64(5), result[account2].CallCount)
require.False(t, result[account2].LastUsedAt.IsZero())
require.Equal(t, int64(0), result[account3].CallCount)
require.True(t, result[account3].LastUsedAt.IsZero())
}
func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() {
t := s.T()
rdb := testRedis(t)
cache := &gatewayCache{rdb: rdb}
ctx := context.Background()
accountID := int64(2001)
model1 := "claude-sonnet-4-20250514"
model2 := "gemini-2.5-pro"
// 对 model1 调用 3 次
for i := 0; i < 3; i++ {
_, err := cache.IncrModelCallCount(ctx, accountID, model1)
require.NoError(t, err)
}
// 获取 model1 的负载
result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1)
require.NoError(t, err)
require.Equal(t, int64(3), result1[accountID].CallCount)
// 获取 model2 的负载(应该为 0
result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2)
require.NoError(t, err)
require.Equal(t, int64(0), result2[accountID].CallCount)
}
// ============ 辅助函数测试 ============
func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() {
t := s.T()
key := modelLoadKey(123, "claude-sonnet-4")
require.Equal(t, "ag:model_load:123:claude-sonnet-4", key)
}
func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() {
t := s.T()
key := modelLastUsedKey(456, "gemini-2.5-pro")
require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key)
}

View File

@@ -191,7 +191,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
groups, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Asc(group.FieldID)).
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
@@ -218,7 +218,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive)).
Order(dbent.Asc(group.FieldID)).
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, err
@@ -245,7 +245,7 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)).
Order(dbent.Asc(group.FieldID)).
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, err
@@ -497,3 +497,29 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
return nil
}
// UpdateSortOrders 批量更新分组排序
func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
if len(updates) == 0 {
return nil
}
// 使用事务批量更新
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, u := range updates {
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}

View File

@@ -896,6 +896,10 @@ func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int
return nil, errors.New("not implemented")
}
func (stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
return nil
}
type stubAccountRepo struct {
bulkUpdateIDs []int64
}
@@ -1004,10 +1008,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return errors.New("not implemented")
}
@@ -1049,6 +1049,10 @@ func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s
return int64(len(ids)), nil
}
func (s *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return nil, errors.New("not implemented")
}
type stubProxyRepo struct{}
func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error {

View File

@@ -192,6 +192,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
@@ -208,6 +209,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)

View File

@@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string {
if baseURL == "" {
return "https://api.anthropic.com"
}
if a.Platform == PlatformAntigravity {
return strings.TrimRight(baseURL, "/") + "/antigravity"
}
return baseURL
}
// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。
// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。
func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string {
baseURL := strings.TrimSpace(a.GetCredential("base_url"))
if baseURL == "" {
return defaultBaseURL
}
if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey {
return strings.TrimRight(baseURL, "/") + "/antigravity"
}
return baseURL
}

View File

@@ -0,0 +1,160 @@
//go:build unit
package service
import (
"testing"
)
func TestGetBaseURL(t *testing.T) {
tests := []struct {
name string
account Account
expected string
}{
{
name: "non-apikey type returns empty",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAnthropic,
},
expected: "",
},
{
name: "apikey without base_url returns default anthropic",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAnthropic,
Credentials: map[string]any{},
},
expected: "https://api.anthropic.com",
},
{
name: "apikey with custom base_url",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAnthropic,
Credentials: map[string]any{"base_url": "https://custom.example.com"},
},
expected: "https://custom.example.com",
},
{
name: "antigravity apikey auto-appends /antigravity",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity apikey trims trailing slash before appending",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity non-apikey returns empty",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetBaseURL()
if result != tt.expected {
t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGetGeminiBaseURL(t *testing.T) {
const defaultGeminiURL = "https://generativelanguage.googleapis.com"
tests := []struct {
name string
account Account
expected string
}{
{
name: "apikey without base_url returns default",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{},
},
expected: defaultGeminiURL,
},
{
name: "apikey with custom base_url",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"},
},
expected: "https://custom-gemini.example.com",
},
{
name: "antigravity apikey auto-appends /antigravity",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity apikey trims trailing slash",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity oauth does NOT append /antigravity",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com",
},
{
name: "oauth without base_url returns default",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{},
},
expected: defaultGeminiURL,
},
{
name: "nil credentials returns default",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
expected: defaultGeminiURL,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetGeminiBaseURL(defaultGeminiURL)
if result != tt.expected {
t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected)
}
})
}
}

View File

@@ -25,6 +25,9 @@ type AccountRepository interface {
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error
@@ -50,7 +53,6 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error

View File

@@ -54,6 +54,10 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st
panic("unexpected GetByCRSAccountID call")
}
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
panic("unexpected ListCRSAccountIDs call")
}
func (s *accountRepoStub) Update(ctx context.Context, account *Account) error {
panic("unexpected Update call")
}
@@ -143,10 +147,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call")
}
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
panic("unexpected SetModelRateLimit call")
}

View File

@@ -245,7 +245,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Set common headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
// Apply Claude Code client headers
for key, value := range claude.DefaultHeaders {
@@ -254,8 +253,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Set authentication header
if useBearer {
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
req.Header.Set("Authorization", "Bearer "+authToken)
} else {
req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader)
req.Header.Set("x-api-key", authToken)
}

View File

@@ -36,6 +36,7 @@ type AdminService interface {
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
@@ -1015,6 +1016,10 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return keys, result.Total, nil
}
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}

View File

@@ -172,6 +172,10 @@ func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return nil
}
type proxyRepoStub struct {
deleteErr error
countErr error

View File

@@ -116,6 +116,10 @@ func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []i
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
return nil
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
@@ -395,6 +399,10 @@ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Contex
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStubForFallbackCycle) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
return nil
}
type groupRepoStubForInvalidRequestFallback struct {
groups map[int64]*Group
created *Group
@@ -466,6 +474,10 @@ func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.C
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForInvalidRequestFallback) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
return nil
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{

View File

@@ -0,0 +1,79 @@
package service
import (
"encoding/json"
"strings"
"time"
)
// Anthropic 会话 Fallback 相关常量
const (
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL5 分钟)
anthropicSessionTTLSeconds = 300
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)
// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL
func AnthropicSessionTTL() time.Duration {
return anthropicSessionTTLSeconds * time.Second
}
// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链
// 格式: s:<hash>-u:<hash>-a:<hash>-u:<hash>-...
// s = system, u = user, a = assistant
func BuildAnthropicDigestChain(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var parts []string
// 1. system prompt
if parsed.System != nil {
systemData, _ := json.Marshal(parsed.System)
if len(systemData) > 0 && string(systemData) != "null" {
parts = append(parts, "s:"+shortHash(systemData))
}
}
// 2. messages
for _, msg := range parsed.Messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
role, _ := msgMap["role"].(string)
prefix := rolePrefix(role)
content := msgMap["content"]
contentData, _ := json.Marshal(content)
parts = append(parts, prefix+":"+shortHash(contentData))
}
return strings.Join(parts, "-")
}
// rolePrefix 将 Anthropic 的 role 映射为单字符前缀
func rolePrefix(role string) string {
switch role {
case "assistant":
return "a"
default:
return "u"
}
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
prefix := prefixHash
if len(prefixHash) >= 8 {
prefix = prefixHash[:8]
}
uuidPart := uuid
if len(uuid) >= 8 {
uuidPart = uuid[:8]
}
return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart
}

View File

@@ -0,0 +1,320 @@
package service
import (
"strings"
"testing"
)
func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) {
result := BuildAnthropicDigestChain(nil)
if result != "" {
t.Errorf("expected empty string for nil request, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{},
}
result := BuildAnthropicDigestChain(parsed)
if result != "" {
t.Errorf("expected empty string for empty messages, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "a:") {
t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) {
parsed := &ParsedRequest{
System: "You are a helpful assistant",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "u:") {
t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) {
parsed := &ParsedRequest{
System: []any{
map[string]any{"type": "text", "text": "You are a helpful assistant"},
},
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) {
// 核心测试:验证对话增长时链的前缀关系
// 上一轮的完整链一定是下一轮链的前缀
system := "You are a helpful assistant"
// 第 1 轮: system + user
round1 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(round1)
// 第 2 轮: system + user + assistant + user
round2 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
},
}
chain2 := BuildAnthropicDigestChain(round2)
// 第 3 轮: system + user + assistant + user + assistant + user
round3 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
map[string]any{"role": "assistant", "content": "I'm doing well"},
map[string]any{"role": "user", "content": "great"},
},
}
chain3 := BuildAnthropicDigestChain(round3)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
t.Logf("Chain3: %s", chain3)
// chain1 是 chain2 的前缀
if !strings.HasPrefix(chain2, chain1) {
t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2)
}
// chain2 是 chain3 的前缀
if !strings.HasPrefix(chain3, chain2) {
t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3)
}
// chain1 也是 chain3 的前缀(传递性)
if !strings.HasPrefix(chain3, chain1) {
t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3)
}
}
func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
System: "System A",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
parsed2 := &ParsedRequest{
System: "System B",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different system prompts should produce different chains")
}
// 但 user 部分的 hash 应该相同
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
if parts1[1] != parts2[1] {
t.Error("Same user message should produce same hash regardless of system")
}
}
func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "ORIGINAL reply"},
map[string]any{"role": "user", "content": "next"},
},
}
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "TAMPERED reply"},
map[string]any{"role": "user", "content": "next"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different content should produce different chains")
}
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
// 第一个 user message hash 应该相同
if parts1[0] != parts2[0] {
t.Error("First user message hash should be the same")
}
// assistant reply hash 应该不同
if parts1[1] == parts2[1] {
t.Error("Assistant reply hash should differ")
}
}
func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
System: "test system",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi"},
},
}
chain1 := BuildAnthropicDigestChain(parsed)
chain2 := BuildAnthropicDigestChain(parsed)
if chain1 != chain2 {
t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2)
}
}
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
tests := []struct {
name string
prefixHash string
uuid string
want string
}{
{
name: "normal 16 char hash with uuid",
prefixHash: "abcdefgh12345678",
uuid: "550e8400-e29b-41d4-a716-446655440000",
want: "anthropic:digest:abcdefgh:550e8400",
},
{
name: "exactly 8 chars",
prefixHash: "12345678",
uuid: "abcdefgh",
want: "anthropic:digest:12345678:abcdefgh",
},
{
name: "short values",
prefixHash: "abc",
uuid: "xyz",
want: "anthropic:digest:abc:xyz",
},
{
name: "empty values",
prefixHash: "",
uuid: "",
want: "anthropic:digest::",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid)
if got != tt.want {
t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
}
})
}
// 验证不同 uuid 产生不同 sessionKey
t.Run("different uuid different key", func(t *testing.T) {
hash := "sameprefix123456"
result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a")
result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b")
if result1 == result2 {
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
}
})
}
func TestAnthropicSessionTTL(t *testing.T) {
ttl := AnthropicSessionTTL()
if ttl.Seconds() != 300 {
t.Errorf("expected 300 seconds, got: %v", ttl.Seconds())
}
}
func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) {
// 测试 content 为 content blocks 数组的情况
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "describe this image"},
map[string]any{"type": "image", "source": map[string]any{"type": "base64"}},
},
},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,17 +4,42 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter
type antigravityFailingWriter struct {
gin.ResponseWriter
failAfter int // 允许成功写入的次数,之后所有写入返回错误
writes int
}
func (w *antigravityFailingWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed: client disconnected")
}
w.writes++
return w.ResponseWriter.Write(p)
}
// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService
func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService {
return &AntigravityGatewayService{
settingService: &SettingService{cfg: cfg},
}
}
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
req := &antigravity.ClaudeRequest{
Model: "claude-sonnet-4-5",
@@ -337,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling
// 验证:ForwardGemini 粘性会话切换时UpstreamFailoverError.ForceCacheBilling 应为 true
// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies
// that ForwardGemini sets ForceCacheBilling=true for sticky session switch.
func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
@@ -391,3 +416,438 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
// --- 流式 happy path 测试 ---
// TestStreamUpstreamResponse_NormalComplete
// 验证正常流式转发完成时数据正确透传、usage 正确收集、clientDisconnect=false
func TestStreamUpstreamResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: content_block_delta`)
fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta")
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "event: message_start")
require.Contains(t, body, "content_block_delta")
require.Contains(t, body, "message_delta")
}
// TestHandleGeminiStreamingResponse_NormalComplete
// 验证:正常 Gemini 流式转发数据正确透传、usage 正确收集
func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// 第一个 chunk部分内容
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`)
fmt.Fprintln(pw, "")
// 第二个 chunk最终内容+完整 usage
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2
// → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2
require.Equal(t, 8, result.usage.InputTokens)
require.Equal(t, 8, result.usage.OutputTokens)
require.Equal(t, 2, result.usage.CacheReadInputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证数据被透传到客户端
body := rec.Body.String()
require.Contains(t, body, "Hello")
require.Contains(t, body, "world")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// TestHandleClaudeStreamingResponse_NormalComplete
// 验证:正常 Claude 流式转发Gemini→Claude 转换),数据正确转换并输出
func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式Gemini 数据嵌套在 "response" 字段下
// ProcessLine 先尝试反序列化为 V1InternalResponse裸格式会导致 Response.UsageMetadata 为空
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect")
require.NotNil(t, result.usage)
// Gemini→Claude 转换的 usagepromptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3
require.Equal(t, 5, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.NotNil(t, result.firstTokenMs, "should record first token time")
// 验证输出是 Claude SSE 格式processor 会转换)
body := rec.Body.String()
require.Contains(t, body, "event: message_start", "should contain Claude message_start event")
require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event")
// 不应包含错误事件
require.NotContains(t, body, "event: error")
}
// --- 流式客户端断开检测测试 ---
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
// 验证客户端写入失败后streamUpstreamResponse 继续读取上游以收集 usage
func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `event: message_start`)
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`)
fmt.Fprintln(pw, "")
fmt.Fprintln(pw, `event: message_delta`)
fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`)
fmt.Fprintln(pw, "")
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotNil(t, result.usage)
require.Equal(t, 20, result.usage.OutputTokens)
}
// TestStreamUpstreamResponse_ContextCanceled
// 验证context 取消时返回 usage 且标记 clientDisconnect
func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestStreamUpstreamResponse_Timeout
// 验证:上游超时时返回已收集的 usage
func TestStreamUpstreamResponse_Timeout(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.False(t, result.clientDisconnect)
}
// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect
// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect
func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`)
fmt.Fprintln(pw, "")
// 不关闭 pw → 等待超时
}()
result := svc.streamUpstreamResponse(c, resp, time.Now())
_ = pw.Close()
_ = pr.Close()
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleGeminiStreamingResponse_ClientDisconnect
// 验证Gemini 流式转发中客户端断开后继续 drain 上游
func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "write_failed")
}
// TestHandleGeminiStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestHandleClaudeStreamingResponse_ClientDisconnect
// 验证Claude 流式转发中客户端断开后继续 drain 上游
func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
go func() {
defer func() { _ = pw.Close() }()
// v1internal 包装格式
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`)
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
// TestHandleClaudeStreamingResponse_ContextCanceled
// 验证context 取消时不注入错误事件
func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := newAntigravityTestService(&config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
})
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.NotContains(t, rec.Body.String(), "event: error")
}
// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage
func TestExtractSSEUsage(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
line string
expected ClaudeUsage
}{
{
name: "message_delta with output_tokens",
line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`,
expected: ClaudeUsage{OutputTokens: 42},
},
{
name: "non-data line ignored",
line: `event: message_start`,
expected: ClaudeUsage{},
},
{
name: "top-level usage with all fields",
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
usage := &ClaudeUsage{}
svc.extractSSEUsage(tt.line, usage)
require.Equal(t, tt.expected, *usage)
})
}
}
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
func TestAntigravityClientWriter(t *testing.T) {
t.Run("normal write succeeds", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "test")
ok := cw.Write([]byte("hello"))
require.True(t, ok)
require.False(t, cw.Disconnected())
require.Contains(t, rec.Body.String(), "hello")
})
t.Run("write failure marks disconnected", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
ok := cw.Write([]byte("hello"))
require.False(t, ok)
require.True(t, cw.Disconnected())
})
t.Run("subsequent writes are no-op", func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(fw, flusher, "test")
cw.Write([]byte("first"))
ok := cw.Fprintf("second %d", 2)
require.False(t, ok)
require.True(t, cw.Disconnected())
})
}

View File

@@ -273,12 +273,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
}
client := antigravity.NewClient(proxyURL)
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
}
if err == nil {
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
return projectID, nil
} else if onboardErr != nil {
lastErr = onboardErr
continue
}
}
// 记录错误
if err != nil {
lastErr = err
@@ -292,6 +301,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
tierID := resolveDefaultTierID(loadRaw)
if tierID == "" {
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
}
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
if err != nil {
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
}
return projectID, nil
}
func resolveDefaultTierID(loadRaw map[string]any) string {
if len(loadRaw) == 0 {
return ""
}
rawTiers, ok := loadRaw["allowedTiers"]
if !ok {
return ""
}
tiers, ok := rawTiers.([]any)
if !ok {
return ""
}
for _, rawTier := range tiers {
tier, ok := rawTier.(map[string]any)
if !ok {
continue
}
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
continue
}
if id, ok := tier["id"].(string); ok {
id = strings.TrimSpace(id)
if id != "" {
return id
}
}
}
return ""
}
// FillProjectID 仅获取 project_id不刷新 OAuth token
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
}
// BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{

View File

@@ -0,0 +1,82 @@
package service
import (
"testing"
)
func TestResolveDefaultTierID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
loadRaw map[string]any
want string
}{
{
name: "nil loadRaw",
loadRaw: nil,
want: "",
},
{
name: "missing allowedTiers",
loadRaw: map[string]any{
"paidTier": map[string]any{"id": "g1-pro-tier"},
},
want: "",
},
{
name: "empty allowedTiers",
loadRaw: map[string]any{"allowedTiers": []any{}},
want: "",
},
{
name: "tier missing id field",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"isDefault": true},
},
},
want: "",
},
{
name: "allowedTiers but no default",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": false},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "",
},
{
name: "default tier found",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": true},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "free-tier",
},
{
name: "default tier id with spaces",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": " standard-tier ", "isDefault": true},
},
},
want: "standard-tier",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := resolveDefaultTierID(tc.loadRaw)
if got != tc.want {
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -2,63 +2,23 @@ package service
import (
"context"
"slices"
"strings"
"time"
)
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type AntigravityQuotaScope string
const (
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
if len(supportedScopes) == 0 {
// 未配置时默认全部支持
return true
}
supported := slices.Contains(supportedScopes, string(scope))
return supported
}
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
return resolveAntigravityQuotaScope(requestedModel)
}
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
if model == "" {
return "", false
}
switch {
case strings.HasPrefix(model, "claude-"):
return AntigravityQuotaScopeClaude, true
case strings.HasPrefix(model, "gemini-"):
if isImageGenerationModel(model) {
return AntigravityQuotaScopeGeminiImage, true
}
return AntigravityQuotaScopeGeminiText, true
default:
return "", false
}
}
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
// 返回空字符串表示无法解析
func resolveAntigravityModelKey(requestedModel string) string {
return normalizeAntigravityModelName(requestedModel)
}
// IsSchedulableForModel 结合模型级限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
@@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
return false
}
if a.Platform != PlatformAntigravity {
return true
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return true
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return true
}
now := time.Now()
return !now.Before(*resetAt)
return true
}
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
if !ok {
return nil
}
rawScope, ok := rawScopes[string(scope)].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}
var antigravityAllScopes = []AntigravityQuotaScope{
AntigravityQuotaScopeClaude,
AntigravityQuotaScopeGeminiText,
AntigravityQuotaScopeGeminiImage,
}
func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 {
if a == nil || a.Platform != PlatformAntigravity {
return nil
}
now := time.Now()
result := make(map[string]int64)
for _, scope := range antigravityAllScopes {
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt != nil && now.Before(*resetAt) {
remainingSec := int64(time.Until(*resetAt).Seconds())
if remainingSec > 0 {
result[string(scope)] = remainingSec
}
}
}
if len(result) == 0 {
return nil
}
return result
}
// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间
// 返回 0 表示未限流或已过期
func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration {
if a == nil || a.Platform != PlatformAntigravity {
return 0
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return 0
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return 0
}
if remaining := time.Until(*resetAt); remaining > 0 {
return remaining
}
return 0
}
// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值)
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel)
if modelRemaining > scopeRemaining {
return modelRemaining
}
return scopeRemaining
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
}

View File

@@ -59,12 +59,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string,
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
type scopeLimitCall struct {
accountID int64
scope AntigravityQuotaScope
resetAt time.Time
}
type rateLimitCall struct {
accountID int64
resetAt time.Time
@@ -78,16 +72,10 @@ type modelRateLimitCall struct {
type stubAntigravityAccountRepo struct {
AccountRepository
scopeCalls []scopeLimitCall
rateCalls []rateLimitCall
modelRateLimitCalls []modelRateLimitCall
}
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
return nil
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
return nil
@@ -98,7 +86,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
return nil
}
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
@@ -131,10 +121,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
quotaScope: AntigravityQuotaScopeClaude,
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCalled = true
return nil
},
@@ -144,32 +133,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.False(t, handleErrorCalled)
require.Len(t, upstream.calls, 2)
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.True(t, handleErrorCalled)
require.Len(t, upstream.calls, antigravityMaxRetries)
for _, callURL := range upstream.calls {
require.True(t, strings.HasPrefix(callURL, base1))
}
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
require.NotEmpty(t, available)
require.Equal(t, base2, available[0])
}
func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) {
// 分区限流始终开启,不再支持通过环境变量关闭
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
body := buildGeminiRateLimitBody("3s")
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
require.Len(t, repo.scopeCalls, 1)
require.Empty(t, repo.rateCalls)
call := repo.scopeCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
require.Equal(t, base1, available[0])
}
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
@@ -189,7 +162,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
// 应该触发模型限流
require.NotNil(t, result)
@@ -200,31 +173,32 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) {
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流
// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底
func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity}
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reasonscope 限流
// 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason走模型级限流兜底
body := buildGeminiRateLimitBody("5s")
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false)
// 不应该触发模型限流,应该走 scope 限流
// handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED
// 但 429 兜底逻辑会使用 requestedModel 设置模型级限流
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Len(t, repo.scopeCalls, 1)
require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
// 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流
// 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号
body := []byte(`{
"error": {
"status": "UNAVAILABLE",
@@ -235,15 +209,15 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 应该触发模型限流
// MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流
// 实际重试由 handleSmartRetry 处理
require.NotNil(t, result)
require.True(t, result.Handled)
require.NotNil(t, result.SwitchError)
require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel)
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path")
require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch")
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
}
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
@@ -263,12 +237,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) {
}
}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 503 非模型限流不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit")
require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit")
require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit")
}
@@ -281,12 +254,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) {
// 503 + 空响应体 → 不做任何处理
body := []byte(`{}`)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false)
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
// 503 空响应不应该做任何处理
require.Nil(t, result)
require.Empty(t, repo.modelRateLimitCalls)
require.Empty(t, repo.scopeCalls)
require.Empty(t, repo.rateCalls)
}
@@ -307,15 +279,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
account.RateLimitResetAt = nil
account.Extra = map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future.Format(time.RFC3339),
},
},
}
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
}
@@ -341,11 +305,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
tests := []struct {
name string
body string
expectedDelay time.Duration
expectedModel string
expectedNil bool
name string
body string
expectedDelay time.Duration
expectedModel string
expectedNil bool
expectedIsModelCapacityExhausted bool
}{
{
name: "valid complete response with RATE_LIMIT_EXCEEDED",
@@ -408,8 +373,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`,
expectedDelay: 39 * time.Second,
expectedModel: "gemini-3-pro-high",
expectedDelay: 39 * time.Second,
expectedModel: "gemini-3-pro-high",
expectedIsModelCapacityExhausted: true,
},
{
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
@@ -520,6 +486,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
if result.ModelName != tt.expectedModel {
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
}
if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
}
})
}
}
@@ -531,13 +500,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
apiKeyAccount := &Account{Type: AccountTypeAPIKey}
tests := []struct {
name string
account *Account
body string
expectedShouldRetry bool
expectedShouldRateLimit bool
minWait time.Duration
modelName string
name string
account *Account
body string
expectedShouldRetry bool
expectedShouldRateLimit bool
expectedIsModelCapacityExhausted bool
minWait time.Duration
modelName string
}{
{
name: "OAuth account with short delay (< 7s) - smart retry",
@@ -635,6 +605,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 7 * time.Second,
modelName: "gemini-pro",
},
{
@@ -650,12 +621,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
]
}
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
modelName: "gemini-3-pro-high",
expectedShouldRetry: true,
expectedShouldRateLimit: false,
expectedIsModelCapacityExhausted: true,
minWait: 1 * time.Second,
modelName: "gemini-3-pro-high",
},
{
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit",
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait",
account: oauthAccount,
body: `{
"error": {
@@ -667,9 +640,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
"message": "No capacity available for model gemini-2.5-flash on the server"
}
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
modelName: "gemini-2.5-flash",
expectedShouldRetry: true,
expectedShouldRateLimit: false,
expectedIsModelCapacityExhausted: true,
minWait: 1 * time.Second,
modelName: "gemini-2.5-flash",
},
{
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
@@ -686,24 +661,33 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
}`,
expectedShouldRetry: false,
expectedShouldRateLimit: true,
minWait: 30 * time.Second,
modelName: "claude-sonnet-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
if shouldRetry != tt.expectedShouldRetry {
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
}
if shouldRateLimit != tt.expectedShouldRateLimit {
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
}
if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
}
if shouldRetry {
if wait < tt.minWait {
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
}
}
if shouldRateLimit && tt.minWait > 0 {
if wait < tt.minWait {
t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait)
}
}
if (shouldRetry || shouldRateLimit) && model != tt.modelName {
t.Errorf("modelName = %q, want %q", model, tt.modelName)
}
@@ -803,7 +787,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) {
require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope")
}
func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) {
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 1,
@@ -815,19 +799,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
// RFC3339 here is second-precision; keep it safely in the future.
"rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339),
},
},
},
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
defer cancel()
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
@@ -836,17 +816,21 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, result)
require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession)
require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check")
}
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) {
func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 2,
@@ -875,7 +859,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t
requestedModel: "claude-sonnet-4-5",
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
@@ -946,6 +930,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) {
}
}
func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
defer func() {
antigravity.BaseURLs = oldBaseURLs
}()
prodURL := "https://prod.test"
dailyURL := "https://daily.test"
antigravity.BaseURLs = []string{dailyURL, prodURL}
resolved := resolveAntigravityForwardBaseURL()
require.Equal(t, dailyURL, resolved)
}
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
err := &AntigravityAccountSwitchError{
OriginalAccountID: 789,

View File

@@ -0,0 +1,904 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// 辅助函数:构造带 SingleAccountRetry 标记的 context
// ---------------------------------------------------------------------------
func ctxWithSingleAccountRetry() context.Context {
return context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
}
// ---------------------------------------------------------------------------
// 1. isSingleAccountRetry 测试
// ---------------------------------------------------------------------------
func TestIsSingleAccountRetry_True(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
require.True(t, isSingleAccountRetry(ctx))
}
func TestIsSingleAccountRetry_False_NoValue(t *testing.T) {
require.False(t, isSingleAccountRetry(context.Background()))
}
func TestIsSingleAccountRetry_False_ExplicitFalse(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, false)
require.False(t, isSingleAccountRetry(ctx))
}
func TestIsSingleAccountRetry_False_WrongType(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, "true")
require.False(t, isSingleAccountRetry(ctx))
}
// ---------------------------------------------------------------------------
// 2. 常量验证
// ---------------------------------------------------------------------------
func TestSingleAccountRetryConstants(t *testing.T) {
require.Equal(t, 3, antigravitySingleAccountSmartRetryMaxAttempts,
"单账号原地重试最多 3 次")
require.Equal(t, 15*time.Second, antigravitySingleAccountSmartRetryMaxWait,
"单次最大等待 15s")
require.Equal(t, 30*time.Second, antigravitySingleAccountSmartRetryTotalMaxWait,
"总累计等待不超过 30s")
}
// ---------------------------------------------------------------------------
// 3. handleSmartRetry + 503 + SingleAccountRetry → 走 handleSingleAccountRetryInPlace
// (而非设模型限流 + 切换账号)
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace
// 核心场景503 + retryDelay >= 7s + SingleAccountRetry 标记
// → 不设模型限流、不切换账号,改为原地重试
func TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace(t *testing.T) {
// 原地重试成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 1,
Name: "acc-single",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
// 503 + 39s >= 7s 阈值 + MODEL_CAPACITY_EXHAUSTED
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(), // 关键:设置单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键断言:返回 resp原地重试成功而非 switchError切换账号
require.NotNil(t, result.resp, "should return successful response from in-place retry")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT return switchError in single account mode")
require.Nil(t, result.err)
// 验证未设模型限流(单账号模式不应设限流)
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit in single account retry mode")
// 验证确实调用了 upstream原地重试
require.GreaterOrEqual(t, len(upstream.calls), 1, "should have made at least one retry call")
}
// TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches
// 对照组503 + retryDelay >= 7s + 无 SingleAccountRetry 标记
// → 照常设模型限流 + 切换账号
func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 2,
Name: "acc-multi",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel
respBody := []byte(`{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(), // 关键:无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 对照:多账号模式返回 switchError
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
require.Nil(t, result.resp, "should not return resp when switchError is set")
// 对照:多账号模式应设模型限流
require.Len(t, repo.modelRateLimitCalls, 1,
"multi-account mode SHOULD set model rate limit")
}
// TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches
// 边界情况429非 503+ SingleAccountRetry 标记
// → 单账号原地重试仅针对 503429 依然走切换账号逻辑
func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-429",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 429 + 15s >= 7s 阈值
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests, // 429不是 503
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(), // 有单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 429 即使有单账号标记,也应走切换账号
require.NotNil(t, result.switchError, "429 should still return switchError even with SingleAccountRetry")
require.Len(t, repo.modelRateLimitCalls, 1,
"429 should still set model rate limit even with SingleAccountRetry")
}
// ---------------------------------------------------------------------------
// 4. handleSmartRetry + 503 + 短延迟 + SingleAccountRetry → 智能重试耗尽后不设限流
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503不设限流
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
// 智能重试也返回 503
failRespBody := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 4,
Name: "acc-short-503",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 0.1s < 7s 阈值
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键断言:单账号 503 模式下,智能重试耗尽后直接返回 503 响应,不切换
require.NotNil(t, result.resp, "should return 503 response directly for single account mode")
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT switch account in single account mode")
// 关键断言:不设模型限流
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit for 503 in single account mode")
}
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
// 对照组503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED因为后者走独立的 60 次重试路径
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
failRespBody := `{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 5,
Name: "acc-multi-503",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(), // 无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 对照:多账号模式应返回 switchError
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
// 对照:多账号模式应设模型限流
require.Len(t, repo.modelRateLimitCalls, 1,
"multi-account mode should set model rate limit")
}
// ---------------------------------------------------------------------------
// 5. handleSingleAccountRetryInPlace 直接测试
// ---------------------------------------------------------------------------
// TestHandleSingleAccountRetryInPlace_Success 原地重试成功
func TestHandleSingleAccountRetryInPlace_Success(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
account := &Account{
ID: 10,
Name: "acc-inplace-ok",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not switch account on success")
require.Nil(t, result.err)
}
// TestHandleSingleAccountRetryInPlace_AllRetriesFail 所有重试都失败,返回 503不设限流
func TestHandleSingleAccountRetryInPlace_AllRetriesFail(t *testing.T) {
// 构造 3 个 503 响应(对应 3 次原地重试)
var responses []*http.Response
var errors []error
for i := 0; i < antigravitySingleAccountSmartRetryMaxAttempts; i++ {
responses = append(responses, &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)),
})
errors = append(errors, nil)
}
upstream := &mockSmartRetryUpstream{
responses: responses,
errors: errors,
}
account := &Account{
ID: 11,
Name: "acc-inplace-fail",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
origBody := []byte(`{"error":{"code":503,"status":"UNAVAILABLE"}}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{"X-Test": {"original"}},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, origBody, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键:返回 503 resp不返回 switchError
require.NotNil(t, result.resp, "should return 503 response directly")
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT return switchError - let Handler handle it")
require.Nil(t, result.err)
// 验证确实重试了指定次数
require.Len(t, upstream.calls, antigravitySingleAccountSmartRetryMaxAttempts,
"should have made exactly maxAttempts retry calls")
}
// TestHandleSingleAccountRetryInPlace_WaitDurationClamped 等待时间被限制在 [min, max] 范围
func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
// 用短延迟的成功响应,只验证不 panic
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
account := &Account{
ID: 12,
Name: "acc-clamp",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Equal(t, http.StatusOK, result.resp.StatusCode)
}
// TestHandleSingleAccountRetryInPlace_ContextCanceled context 取消时立即返回
func TestHandleSingleAccountRetryInPlace_ContextCanceled(t *testing.T) {
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil},
errors: []error{nil},
}
account := &Account{
ID: 13,
Name: "acc-cancel",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
ctx, cancel := context.WithCancel(context.Background())
ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
cancel() // 立即取消
params := antigravityRetryLoopParams{
ctx: ctx,
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Error(t, result.err, "should return context error")
// 不应调用 upstream因为在等待阶段就被取消了
require.Len(t, upstream.calls, 0, "should not call upstream when context is canceled")
}
// TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry 网络错误时继续重试
func TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
// 第1次网络错误nil resp第2次成功
responses: []*http.Response{nil, successResp},
errors: []error{nil, nil},
}
account := &Account{
ID: 14,
Name: "acc-net-retry",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after network error recovery")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Len(t, upstream.calls, 2, "first call fails (network error), second succeeds")
}
// ---------------------------------------------------------------------------
// 6. antigravityRetryLoop 预检查:单账号模式跳过限流
// ---------------------------------------------------------------------------
// TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit
// 预检查中,如果有 SingleAccountRetry 标记,即使账号已限流也跳过直接发请求
func TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit(t *testing.T) {
// 创建一个已设模型限流的账号
upstream := &recordingOKUpstream{}
account := &Account{
ID: 20,
Name: "acc-rate-limited",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err, "should not return error")
require.NotNil(t, result, "should return result")
require.NotNil(t, result.resp, "should have response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
// 关键:尽管限流了,有 SingleAccountRetry 标记时仍然到达了 upstream
require.Equal(t, 1, upstream.calls, "should have reached upstream despite rate limit")
}
// TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit
// 对照组:无 SingleAccountRetry + 已限流 → 预检查返回 switchError
func TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 21,
Name: "acc-rate-limited-multi",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(), // 无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.Nil(t, result, "should not return result on rate limit switch")
require.NotNil(t, err, "should return error")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr, "should return AntigravityAccountSwitchError")
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
// upstream 不应被调用(预检查就短路了)
require.Equal(t, 0, upstream.calls, "upstream should NOT be called when pre-check blocks")
}
// ---------------------------------------------------------------------------
// 7. 端到端集成场景测试
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E
// 端到端场景503 + 单账号 + 原地重试第2次成功
func TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E(t *testing.T) {
// 第1次原地重试仍返回 503第2次成功
fail503Body := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
resp503 := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(fail503Body)),
}
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{resp503, successResp},
errors: []error{nil, nil},
}
account := &Account{
ID: 30,
Name: "acc-e2e",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after 2nd attempt")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError)
require.Len(t, upstream.calls, 2, "first 503, second OK")
}
// TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E
// 通过 antigravityRetryLoop → handleSmartRetry → handleSingleAccountRetryInPlace 完整链路
func TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E(t *testing.T) {
// 初始请求返回 503 + 长延迟
initial503Body := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "10s"}
],
"message": "No capacity available"
}
}`)
initial503Resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(initial503Body)),
}
// 原地重试成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
// 第1次调用retryLoop 主循环)返回 503
// 第2次调用handleSingleAccountRetryInPlace 原地重试)返回 200
responses: []*http.Response{initial503Resp, successResp},
errors: []error{nil, nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 31,
Name: "acc-e2e-loop",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err, "should not return error on successful retry")
require.NotNil(t, result, "should return result")
require.NotNil(t, result.resp, "should return response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
// 验证未设模型限流
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit in single account retry mode")
}

View File

@@ -13,6 +13,23 @@ import (
"github.com/stretchr/testify/require"
)
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
// 仅关注 DeleteSessionAccountID 的调用记录
type stubSmartRetryCache struct {
GatewayCache // 嵌入接口,未实现的方法 panic确保只调用预期方法
deleteCalls []deleteSessionCall
}
type deleteSessionCall struct {
groupID int64
sessionHash string
}
func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID int64, sessionHash string) error {
c.deleteCalls = append(c.deleteCalls, deleteSessionCall{groupID: groupID, sessionHash: sessionHash})
return nil
}
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
type mockSmartRetryUpstream struct {
responses []*http.Response
@@ -58,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) {
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -110,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) {
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -177,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -198,7 +215,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) {
// TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError
func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) {
// 智能重试后仍然返回 429需要提供 3 个响应,因为智能重试最多 3 次)
// 智能重试后仍然返回 429需要提供 1 个响应,因为智能重试最多 1 次)
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
@@ -213,19 +230,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
failResp2 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
failResp3 := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp1, failResp2, failResp3},
errors: []error{nil, nil, nil},
responses: []*http.Response{failResp1},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
@@ -236,7 +243,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
Platform: PlatformAntigravity,
}
// 3s < 7s 阈值,应该触发智能重试(最多 3 次)
// 3s < 7s 阈值,应该触发智能重试(最多 1 次)
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
@@ -262,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
httpUpstream: upstream,
accountRepo: repo,
isStickySession: false,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -284,11 +291,12 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey)
require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)")
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) {
// TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功
// MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号
func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
@@ -297,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
Platform: PlatformAntigravity,
}
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s
respBody := []byte(`{
"error": {
"code": 503,
@@ -315,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
Body: io.NopCloser(bytes.NewReader(respBody)),
}
// mock: 第 1 次重试返回 200 成功
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{
{StatusCode: http.StatusOK, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))},
},
errors: []error{nil},
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
@@ -323,8 +339,9 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
httpUpstream: upstream,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -336,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Nil(t, result.resp)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.err)
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
require.True(t, result.switchError.IsStickySession)
require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError")
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
// 不应设置模型限流
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
require.Len(t, upstream.calls, 1, "should have made one retry call before success")
}
// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消
func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-3",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
// 立即取消上下文,验证重试循环能正确退出
ctx, cancel := context.WithCancel(context.Background())
cancel()
params := antigravityRetryLoopParams{
ctx: ctx,
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Error(t, result.err, "should return context error")
require.Nil(t, result.switchError, "should not return switchError on context cancel")
require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel")
}
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
@@ -380,7 +448,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -429,7 +497,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T)
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -480,7 +548,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -541,7 +609,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
@@ -556,19 +624,15 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing
require.True(t, switchErr.IsStickySession)
}
// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试
func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
// 第一次网络错误,第二次成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
// TestHandleSmartRetry_NetworkError_ExhaustsRetry 测试网络错误时maxAttempts=1直接耗尽重试并切换账号
func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) {
// 唯一一次重试遇到网络错误nil response
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil, successResp}, // 第一次返回 nil模拟网络错误
errors: []error{nil, nil}, // mock 不返回 error靠 nil response 触发
responses: []*http.Response{nil}, // 返回 nil模拟网络错误
errors: []error{nil}, // mock 不返回 error靠 nil response 触发
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 8,
Name: "acc-8",
@@ -600,7 +664,8 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -612,10 +677,15 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) {
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after network error recovery")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not return switchError on success")
require.Len(t, upstream.calls, 2, "should have made two retry calls")
require.Nil(t, result.resp, "should not return resp when switchError is set")
require.NotNil(t, result.switchError, "should return switchError after network error exhausted retry")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel)
require.Len(t, upstream.calls, 1, "should have made one retry call")
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流
@@ -653,7 +723,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
@@ -674,3 +744,617 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) {
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// ---------------------------------------------------------------------------
// 以下测试覆盖本次改动:
// 1. antigravitySmartRetryMaxAttempts = 1仅重试 1 次)
// 2. 智能重试失败后清除粘性会话绑定DeleteSessionAccountID
// ---------------------------------------------------------------------------
// TestSmartRetryMaxAttempts_VerifyConstant 验证常量值为 1
func TestSmartRetryMaxAttempts_VerifyConstant(t *testing.T) {
require.Equal(t, 1, antigravitySmartRetryMaxAttempts,
"antigravitySmartRetryMaxAttempts should be 1 to prevent repeated rate limiting")
}
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession
// 核心场景:粘性会话 + 短延迟重试失败 → 必须清除粘性绑定
func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 10,
Name: "acc-10",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-abc",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
// 验证返回 switchError
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession, "switchError should carry IsStickySession=true")
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
// 核心断言DeleteSessionAccountID 被调用,且参数正确
require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID exactly once")
require.Equal(t, int64(42), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-hash-abc", cache.deleteCalls[0].sessionHash)
// 验证仅重试 1 次
require.Len(t, upstream.calls, 1, "should make exactly 1 retry call (maxAttempts=1)")
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
}
// TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession
// 非粘性会话 + 短延迟重试失败 → 不应调用 DeleteSessionAccountIDsessionHash 为空)
func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 11,
Name: "acc-11",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: false,
groupID: 42,
sessionHash: "", // 非粘性会话sessionHash 为空
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.False(t, result.switchError.IsStickySession)
// 核心断言sessionHash 为空时不应调用 DeleteSessionAccountID
require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID when sessionHash is empty")
}
// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic
// 边界cache 为 nil 时不应 panic
func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(t *testing.T) {
failRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 12,
Name: "acc-12",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-nil-cache",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
// cache 为 nil不应 panic
svc := &AntigravityGatewayService{cache: nil}
require.NotPanics(t, func() {
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
})
}
// TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession
// 重试成功时不应清除粘性会话(只有失败才清除)
func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 13,
Name: "acc-13",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-success",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not return switchError on success")
// 核心断言:重试成功时不应清除粘性会话
require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID on successful retry")
}
// TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry
// 长延迟路径情况1在 handleSmartRetry 中不直接调用 DeleteSessionAccountID
// (清除由 handler 层的 shouldClearStickySession 在下次请求时处理)
func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 14,
Name: "acc-14",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 15s >= 7s 阈值 → 走长延迟路径
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
isStickySession: true,
groupID: 42,
sessionHash: "sticky-hash-long-delay",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 长延迟路径不在 handleSmartRetry 中调用 DeleteSessionAccountID
// (由上游 handler 的 shouldClearStickySession 处理)
require.Len(t, cache.deleteCalls, 0,
"long delay path should NOT call DeleteSessionAccountID in handleSmartRetry (handled by handler layer)")
}
// TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession
// 网络错误耗尽重试 + 粘性会话 → 也应清除粘性绑定
func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t *testing.T) {
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil}, // 网络错误
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 15,
Name: "acc-15",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 99,
sessionHash: "sticky-net-error",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 核心断言:网络错误耗尽重试后也应清除粘性绑定
require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID after network error exhausts retry")
require.Equal(t, int64(99), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-net-error", cache.deleteCalls[0].sessionHash)
}
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
// 429 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
failRespBody := `{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 16,
Name: "acc-16",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 429,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 77,
sessionHash: "sticky-503-short",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{cache: cache}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.NotNil(t, result.switchError)
require.True(t, result.switchError.IsStickySession)
// 验证粘性绑定被清除
require.Len(t, cache.deleteCalls, 1)
require.Equal(t, int64(77), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-503-short", cache.deleteCalls[0].sessionHash)
// 验证模型限流已设置
require.Len(t, repo.modelRateLimitCalls, 1)
require.Equal(t, "gemini-3-pro", repo.modelRateLimitCalls[0].modelKey)
}
// TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates
// 集成测试antigravityRetryLoop → handleSmartRetry → switchError 传播
// 验证 IsStickySession 正确传递到上层,且粘性绑定被清除
func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates(t *testing.T) {
// 初始 429 响应
initialRespBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
initialResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(initialRespBody)),
}
// 智能重试也返回 429
retryRespBody := `{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
retryResp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(retryRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{initialResp, retryResp},
errors: []error{nil, nil},
}
repo := &stubAntigravityAccountRepo{}
cache := &stubSmartRetryCache{}
account := &Account{
ID: 17,
Name: "acc-17",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
svc := &AntigravityGatewayService{cache: cache}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
isStickySession: true,
groupID: 55,
sessionHash: "sticky-loop-test",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.Nil(t, result, "should not return result when switchError")
require.NotNil(t, err, "should return error")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError")
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel)
require.True(t, switchErr.IsStickySession, "IsStickySession must propagate through retryLoop")
// 验证粘性绑定被清除
require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry")
require.Equal(t, int64(55), cache.deleteCalls[0].groupID)
require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash)
}

View File

@@ -7,12 +7,14 @@ import (
"log/slog"
"strconv"
"strings"
"sync"
"time"
)
const (
antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute
antigravityBackfillCooldown = 5 * time.Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
@@ -23,6 +25,7 @@ type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
}
func NewAntigravityTokenProvider(
@@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if err != nil {
return "", err
}
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
p.mergeCredentials(account, tokenInfo)
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
@@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials")
}
// 如果账号还没有 project_id尝试在线补齐避免请求 daily/sandbox 时出现
// "Invalid project resource name projects/"。
// 仅调用 loadProjectIDWithRetry不刷新 OAuth token带冷却机制防止频繁重试。
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
if p.shouldAttemptBackfill(account.ID) {
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
}
}
}
}
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
@@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil
}
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
}
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id冷却期内不重复尝试
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
if v, ok := p.backfillCooldown.Load(accountID); ok {
if lastAttempt, ok := v.(time.Time); ok {
return time.Since(lastAttempt) > antigravityBackfillCooldown
}
}
return true
}
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
p.backfillCooldown.Store(accountID, time.Now())
}
func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {

View File

@@ -0,0 +1,112 @@
package service
import (
"testing"
)
func TestBuildSelectedSet(t *testing.T) {
tests := []struct {
name string
ids []string
wantNil bool
wantSize int
}{
{
name: "nil input returns nil (backward compatible: create all)",
ids: nil,
wantNil: true,
},
{
name: "empty slice returns empty map (create none)",
ids: []string{},
wantNil: false,
wantSize: 0,
},
{
name: "single ID",
ids: []string{"abc-123"},
wantNil: false,
wantSize: 1,
},
{
name: "multiple IDs",
ids: []string{"a", "b", "c"},
wantNil: false,
wantSize: 3,
},
{
name: "duplicate IDs are deduplicated",
ids: []string{"a", "a", "b"},
wantNil: false,
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildSelectedSet(tt.ids)
if tt.wantNil {
if got != nil {
t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got)
}
return
}
if got == nil {
t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids)
}
if len(got) != tt.wantSize {
t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize)
}
// Verify all unique IDs are present
for _, id := range tt.ids {
if _, ok := got[id]; !ok {
t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id)
}
}
})
}
}
func TestShouldCreateAccount(t *testing.T) {
tests := []struct {
name string
crsID string
selectedSet map[string]struct{}
want bool
}{
{
name: "nil set allows all (backward compatible)",
crsID: "any-id",
selectedSet: nil,
want: true,
},
{
name: "empty set blocks all",
crsID: "any-id",
selectedSet: map[string]struct{}{},
want: false,
},
{
name: "ID in set is allowed",
crsID: "abc-123",
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
want: true,
},
{
name: "ID not in set is blocked",
crsID: "xyz-789",
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := shouldCreateAccount(tt.crsID, tt.selectedSet)
if got != tt.want {
t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v",
tt.crsID, tt.selectedSet, got, tt.want)
}
})
}
}

View File

@@ -45,10 +45,11 @@ func NewCRSSyncService(
}
type SyncFromCRSInput struct {
BaseURL string
Username string
Password string
SyncProxies bool
BaseURL string
Username string
Password string
SyncProxies bool
SelectedAccountIDs []string // if non-empty, only create new accounts with these CRS IDs
}
type SyncFromCRSItemResult struct {
@@ -190,25 +191,27 @@ type crsGeminiAPIKeyAccount struct {
Extra map[string]any `json:"extra"`
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
// fetchCRSExport validates the connection parameters, authenticates with CRS,
// and returns the exported accounts. Shared by SyncFromCRS and PreviewFromCRS.
func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, password string) (*crsExportResponse, error) {
if s.cfg == nil {
return nil, errors.New("config is not available")
}
baseURL := strings.TrimSpace(input.BaseURL)
normalizedURL := strings.TrimSpace(baseURL)
if s.cfg.Security.URLAllowlist.Enabled {
normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
normalized, err := normalizeBaseURL(normalizedURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts)
if err != nil {
return nil, err
}
baseURL = normalized
normalizedURL = normalized
} else {
normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
normalized, err := urlvalidator.ValidateURLFormat(normalizedURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
baseURL = normalized
normalizedURL = normalized
}
if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" {
if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" {
return nil, errors.New("username and password are required")
}
@@ -221,12 +224,16 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
if err != nil {
return nil, err
}
exported, err := crsExportAccounts(ctx, client, baseURL, adminToken)
return crsExportAccounts(ctx, client, normalizedURL, adminToken)
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
@@ -241,6 +248,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
),
}
selectedSet := buildSelectedSet(input.SelectedAccountIDs)
var proxies []Proxy
if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx)
@@ -329,6 +338,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
@@ -446,6 +462,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
@@ -569,6 +592,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
@@ -690,6 +720,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
@@ -798,6 +835,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
@@ -909,6 +953,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
if !shouldCreateAccount(src.ID, selectedSet) {
item.Action = "skipped"
item.Error = "not selected"
result.Skipped++
result.Items = append(result.Items, item)
continue
}
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
@@ -1253,3 +1304,102 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
return newCredentials
}
// buildSelectedSet converts a slice of selected CRS account IDs to a set for O(1) lookup.
// Returns nil if ids is nil (field not sent → backward compatible: create all).
// Returns an empty map if ids is non-nil but empty (user selected none → create none).
func buildSelectedSet(ids []string) map[string]struct{} {
if ids == nil {
return nil
}
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
// shouldCreateAccount checks if a new CRS account should be created based on user selection.
// Returns true if selectedSet is nil (backward compatible: create all) or if crsID is in the set.
func shouldCreateAccount(crsID string, selectedSet map[string]struct{}) bool {
if selectedSet == nil {
return true
}
_, ok := selectedSet[crsID]
return ok
}
// PreviewFromCRSResult contains the preview of accounts from CRS before sync.
type PreviewFromCRSResult struct {
NewAccounts []CRSPreviewAccount `json:"new_accounts"`
ExistingAccounts []CRSPreviewAccount `json:"existing_accounts"`
}
// CRSPreviewAccount represents a single account in the preview result.
type CRSPreviewAccount struct {
CRSAccountID string `json:"crs_account_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
}
// PreviewFromCRS connects to CRS, fetches all accounts, and classifies them
// as new or existing by batch-querying local crs_account_id mappings.
func (s *CRSSyncService) PreviewFromCRS(ctx context.Context, input SyncFromCRSInput) (*PreviewFromCRSResult, error) {
exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password)
if err != nil {
return nil, err
}
// Batch query all existing CRS account IDs
existingCRSIDs, err := s.accountRepo.ListCRSAccountIDs(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list existing CRS accounts: %w", err)
}
result := &PreviewFromCRSResult{
NewAccounts: make([]CRSPreviewAccount, 0),
ExistingAccounts: make([]CRSPreviewAccount, 0),
}
classify := func(crsID, kind, name, platform, accountType string) {
preview := CRSPreviewAccount{
CRSAccountID: crsID,
Kind: kind,
Name: defaultName(name, crsID),
Platform: platform,
Type: accountType,
}
if _, exists := existingCRSIDs[crsID]; exists {
result.ExistingAccounts = append(result.ExistingAccounts, preview)
} else {
result.NewAccounts = append(result.NewAccounts, preview)
}
}
for _, src := range exported.Data.ClaudeAccounts {
authType := strings.TrimSpace(src.AuthType)
if authType == "" {
authType = AccountTypeOAuth
}
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, authType)
}
for _, src := range exported.Data.ClaudeConsoleAccounts {
classify(src.ID, src.Kind, src.Name, PlatformAnthropic, AccountTypeAPIKey)
}
for _, src := range exported.Data.OpenAIOAuthAccounts {
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeOAuth)
}
for _, src := range exported.Data.OpenAIResponsesAccounts {
classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeAPIKey)
}
for _, src := range exported.Data.GeminiOAuthAccounts {
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeOAuth)
}
for _, src := range exported.Data.GeminiAPIKeyAccounts {
classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeAPIKey)
}
return result, nil
}

View File

@@ -0,0 +1,69 @@
package service
import (
"strconv"
"strings"
"time"
gocache "github.com/patrickmn/go-cache"
)
// digestSessionTTL 摘要会话默认 TTL
const digestSessionTTL = 5 * time.Minute
// sessionEntry flat cache 条目
type sessionEntry struct {
uuid string
accountID int64
}
// DigestSessionStore 内存摘要会话存储flat cache 实现)
// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry
type DigestSessionStore struct {
cache *gocache.Cache
}
// NewDigestSessionStore 创建内存摘要会话存储
func NewDigestSessionStore() *DigestSessionStore {
return &DigestSessionStore{
cache: gocache.New(digestSessionTTL, time.Minute),
}
}
// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain用于删旧 key。
func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) {
if digestChain == "" {
return
}
ns := buildNS(groupID, prefixHash)
s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration)
if oldDigestChain != "" && oldDigestChain != digestChain {
s.cache.Delete(ns + oldDigestChain)
}
}
// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。
func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" {
return "", 0, "", false
}
ns := buildNS(groupID, prefixHash)
chain := digestChain
for {
if val, ok := s.cache.Get(ns + chain); ok {
if e, ok := val.(*sessionEntry); ok {
return e.uuid, e.accountID, chain, true
}
}
i := strings.LastIndex(chain, "-")
if i < 0 {
return "", 0, "", false
}
chain = chain[:i]
}
}
// buildNS 构建 namespace 前缀
func buildNS(groupID int64, prefixHash string) string {
return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|"
}

View File

@@ -0,0 +1,312 @@
//go:build unit
package service
import (
"fmt"
"sync"
"testing"
"time"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestSessionStore_SaveAndFind(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_PrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
// 保存短链
store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "")
// 用长链查找,应前缀匹配到短链
uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-short", uuid)
assert.Equal(t, int64(10), accountID)
assert.Equal(t, "u:a-m:b", matchedChain)
}
func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a", "uuid-1", 1, "")
store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "")
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "")
// 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的)
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "uuid-3", uuid)
assert.Equal(t, int64(3), accountID)
// 查找中等长度,应匹配到 "u:a-m:b"
uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) {
store := NewDigestSessionStore()
// 第一轮:保存 "u:a-m:b"
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 第二轮:同一 uuid 保存更长的链,传入旧 chain
store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b")
// 旧链 "u:a-m:b" 应已被删除
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found, "old chain should be deleted")
// 新链应能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}
func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) {
store := NewDigestSessionStore()
// 相同系统提示词,不同用户提示词
store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "")
store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "")
uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(200), accountID)
}
func TestDigestSessionStore_NoMatch(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 完全不同的 chain
_, _, _, found := store.Find(1, "prefix", "u:x-m:y")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "")
// 不同 prefixHash 应隔离
_, _, _, found := store.Find(1, "prefix2", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_DifferentGroupID(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 不同 groupID 应隔离
_, _, _, found := store.Find(2, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_EmptyDigestChain(t *testing.T) {
store := NewDigestSessionStore()
// 空链不应保存
store.Save(1, "prefix", "", "uuid-1", 100, "")
_, _, _, found := store.Find(1, "prefix", "")
assert.False(t, found)
}
func TestDigestSessionStore_TTLExpiration(t *testing.T) {
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 立即应该能找到
_, _, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
// 等待过期 + 清理周期
time.Sleep(300 * time.Millisecond)
// 过期后应找不到
_, _, _, found = store.Find(1, "prefix", "u:a-m:b")
assert.False(t, found)
}
func TestDigestSessionStore_ConcurrentSafety(t *testing.T) {
store := NewDigestSessionStore()
var wg sync.WaitGroup
const goroutines = 50
const operations = 100
wg.Add(goroutines)
for g := 0; g < goroutines; g++ {
go func(id int) {
defer wg.Done()
prefix := fmt.Sprintf("prefix-%d", id%5)
for i := 0; i < operations; i++ {
chain := fmt.Sprintf("u:%d-m:%d", id, i)
uuid := fmt.Sprintf("uuid-%d-%d", id, i)
store.Save(1, prefix, chain, uuid, int64(id), "")
store.Find(1, prefix, chain)
}
}(g)
}
wg.Wait()
}
func TestDigestSessionStore_MultipleSessions(t *testing.T) {
store := NewDigestSessionStore()
sessions := []struct {
chain string
uuid string
accountID int64
}{
{"u:session1", "uuid-1", 1},
{"u:session2-m:reply2", "uuid-2", 2},
{"u:session3-m:reply3-u:msg3", "uuid-3", 3},
}
for _, sess := range sessions {
store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "")
}
// 验证每个会话都能正确查找
for _, sess := range sessions {
uuid, accountID, _, found := store.Find(1, "prefix", sess.chain)
require.True(t, found, "should find session: %s", sess.chain)
assert.Equal(t, sess.uuid, uuid)
assert.Equal(t, sess.accountID, accountID)
}
// 验证继续对话的场景
uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg")
require.True(t, found)
assert.Equal(t, "uuid-2", uuid)
assert.Equal(t, int64(2), accountID)
}
func TestDigestSessionStore_Performance1000Sessions(t *testing.T) {
store := NewDigestSessionStore()
// 插入 1000 个会话
for i := 0; i < 1000; i++ {
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
// 查找性能测试
start := time.Now()
const lookups = 10000
for i := 0; i < lookups; i++ {
idx := i % 1000
chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx)
_, _, _, found := store.Find(1, "prefix", chain)
assert.True(t, found)
}
elapsed := time.Since(start)
t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups)
}
func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) {
store := NewDigestSessionStore()
store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "")
// 精确匹配
_, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
// 前缀匹配(截断后命中)
_, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e")
require.True(t, found)
assert.Equal(t, "u:a-m:b-u:c", matchedChain)
}
func TestDigestSessionStore_CacheItemCountStable(t *testing.T) {
store := NewDigestSessionStore()
// 模拟 100 个独立会话,每个进行 10 轮对话
// 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key
for conv := 0; conv < 100; conv++ {
var prevMatchedChain string
for round := 0; round < 10; round++ {
chain := fmt.Sprintf("s:sys-u:user%d", conv)
for r := 0; r < round; r++ {
chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1)
}
uuid := fmt.Sprintf("uuid-conv%d", conv)
_, _, matched, _ := store.Find(1, "prefix", chain)
store.Save(1, "prefix", chain, uuid, int64(conv), matched)
prevMatchedChain = matched
_ = prevMatchedChain
}
}
// 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key
// 允许少量并发残留,但绝不能接近 100×10=1000
itemCount := store.cache.ItemCount()
assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount)
t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount)
}
func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) {
// 使用极短 TTL 验证大量写入后 cache 能被清理
store := &DigestSessionStore{
cache: gocache.New(100*time.Millisecond, 50*time.Millisecond),
}
// 插入 500 个不同的 key无 oldDigestChain模拟最坏场景全是新会话首轮
for i := 0; i < 500; i++ {
chain := fmt.Sprintf("u:user%d", i)
store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "")
}
assert.Equal(t, 500, store.cache.ItemCount())
// 等待 TTL + 清理周期
time.Sleep(300 * time.Millisecond)
assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up")
}
func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) {
store := NewDigestSessionStore()
// 保存 chain
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "")
// 用户重发相同消息oldDigestChain == digestChain不应删掉刚设置的 key
store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b")
// 仍然能找到
uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b")
require.True(t, found)
assert.Equal(t, "uuid-1", uuid)
assert.Equal(t, int64(100), accountID)
}

View File

@@ -61,6 +61,11 @@ func applyErrorPassthroughRule(
errMsg = *rule.CustomMessage
}
// 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。
if rule.SkipMonitoring {
c.Set(OpsSkipPassthroughKey, true)
}
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
errType = "upstream_error"
return status, errType, errMsg, true

View File

@@ -194,6 +194,63 @@ func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
assert.Equal(t, "Gemini上游失败", errField["message"])
}
func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
rule.SkipMonitoring = true
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
BindErrorPassthroughService(c, ruleSvc)
_, _, _, matched := applyErrorPassthroughRule(
c,
PlatformAnthropic,
http.StatusBadRequest,
[]byte(`{"error":{"message":"prompt is too long"}}`),
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
)
assert.True(t, matched)
v, exists := c.Get(OpsSkipPassthroughKey)
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
boolVal, ok := v.(bool)
assert.True(t, ok, "value should be bool")
assert.True(t, boolVal)
}
func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
rule.SkipMonitoring = false
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
BindErrorPassthroughService(c, ruleSvc)
_, _, _, matched := applyErrorPassthroughRule(
c,
PlatformAnthropic,
http.StatusBadRequest,
[]byte(`{"error":{"message":"prompt is too long"}}`),
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
)
assert.True(t, matched)
_, exists := c.Get(OpsSkipPassthroughKey)
assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false")
}
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
return &model.ErrorPassthroughRule{
ID: 1,

View File

@@ -45,10 +45,20 @@ type ErrorPassthroughService struct {
cache ErrorPassthroughCache
// 本地内存缓存,用于快速匹配
localCache []*model.ErrorPassthroughRule
localCache []*cachedPassthroughRule
localCacheMu sync.RWMutex
}
// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower
type cachedPassthroughRule struct {
*model.ErrorPassthroughRule
lowerKeywords []string // 预计算的小写关键词
lowerPlatforms []string // 预计算的小写平台
errorCodeSet map[int]struct{} // 预计算的 error code set
}
const maxBodyMatchLen = 8 << 10 // 8KB错误信息不会在 8KB 之后才出现
// NewErrorPassthroughService 创建错误透传规则服务
func NewErrorPassthroughService(
repo ErrorPassthroughRepository,
@@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
return nil
}
bodyStr := strings.ToLower(string(body))
lowerPlatform := strings.ToLower(platform)
var bodyLower string // 延迟初始化,只在需要关键词匹配时计算
var bodyLowerDone bool
for _, rule := range rules {
if !rule.Enabled {
continue
}
if !s.platformMatches(rule, platform) {
if !s.platformMatchesCached(rule, lowerPlatform) {
continue
}
if s.ruleMatches(rule, statusCode, bodyStr) {
return rule
if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) {
return rule.ErrorPassthroughRule
}
}
@@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
}
// getCachedRules 获取缓存的规则列表(按优先级排序)
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
s.localCacheMu.RLock()
rules := s.localCache
s.localCacheMu.RUnlock()
@@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
return nil
}
// setLocalCache 设置本地缓存
// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
cached := make([]*cachedPassthroughRule, len(rules))
for i, r := range rules {
cr := &cachedPassthroughRule{ErrorPassthroughRule: r}
if len(r.Keywords) > 0 {
cr.lowerKeywords = make([]string, len(r.Keywords))
for j, kw := range r.Keywords {
cr.lowerKeywords[j] = strings.ToLower(kw)
}
}
if len(r.Platforms) > 0 {
cr.lowerPlatforms = make([]string, len(r.Platforms))
for j, p := range r.Platforms {
cr.lowerPlatforms[j] = strings.ToLower(p)
}
}
if len(r.ErrorCodes) > 0 {
cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes))
for _, code := range r.ErrorCodes {
cr.errorCodeSet[code] = struct{}{}
}
}
cached[i] = cr
}
// 按优先级排序
sorted := make([]*model.ErrorPassthroughRule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
sort.Slice(cached, func(i, j int) bool {
return cached[i].Priority < cached[j].Priority
})
s.localCacheMu.Lock()
s.localCache = sorted
s.localCache = cached
s.localCacheMu.Unlock()
}
@@ -273,62 +307,79 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
}
}
// platformMatches 检查平台是否匹配
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
// 如果没有配置平台限制,则匹配所有平台
if len(rule.Platforms) == 0 {
// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB
func ensureBodyLower(body []byte, bodyLower *string, done *bool) string {
if *done {
return *bodyLower
}
b := body
if len(b) > maxBodyMatchLen {
b = b[:maxBodyMatchLen]
}
*bodyLower = strings.ToLower(string(b))
*done = true
return *bodyLower
}
// platformMatchesCached 使用预计算的小写平台检查是否匹配
func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool {
if len(rule.lowerPlatforms) == 0 {
return true
}
platform = strings.ToLower(platform)
for _, p := range rule.Platforms {
if strings.ToLower(p) == platform {
for _, p := range rule.lowerPlatforms {
if p == lowerPlatform {
return true
}
}
return false
}
// ruleMatches 检查规则是否匹配
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
hasErrorCodes := len(rule.ErrorCodes) > 0
hasKeywords := len(rule.Keywords) > 0
// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换
func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool {
hasErrorCodes := len(rule.errorCodeSet) > 0
hasKeywords := len(rule.lowerKeywords) > 0
// 如果没有配置任何条件,不匹配
if !hasErrorCodes && !hasKeywords {
return false
}
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode)
if rule.MatchMode == model.MatchModeAll {
// "all" 模式:所有配置的条件都必须满足
return codeMatch && keywordMatch
// "all" 模式:所有配置的条件都必须满足,短路
if hasErrorCodes && !codeMatch {
return false
}
if hasKeywords {
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
}
return codeMatch
}
// "any" 模式:任一条件满足即可
// "any" 模式:任一条件满足即可,短路
if hasErrorCodes && hasKeywords {
return codeMatch || keywordMatch
if codeMatch {
return true
}
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
}
return codeMatch && keywordMatch
// 只配置了一种条件
if hasKeywords {
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
}
return codeMatch
}
// containsInt 检查切片是否包含指定整数
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
for _, v := range slice {
if v == val {
return true
}
}
return false
}
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
for _, kw := range keywords {
if strings.Contains(bodyLower, strings.ToLower(kw)) {
// containsIntSet 使用 map 查找替代线性扫描
func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool {
_, ok := set[val]
return ok
}
// containsAnyKeywordCached 使用预计算的小写关键词检查匹配
func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool {
for _, kw := range lowerKeywords {
if strings.Contains(bodyLower, kw) {
return true
}
}

View File

@@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic
return svc
}
// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule测试用
func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule {
cr := &cachedPassthroughRule{ErrorPassthroughRule: rule}
if len(rule.Keywords) > 0 {
cr.lowerKeywords = make([]string, len(rule.Keywords))
for j, kw := range rule.Keywords {
cr.lowerKeywords[j] = strings.ToLower(kw)
}
}
if len(rule.Platforms) > 0 {
cr.lowerPlatforms = make([]string, len(rule.Platforms))
for j, p := range rule.Platforms {
cr.lowerPlatforms[j] = strings.ToLower(p)
}
}
if len(rule.ErrorCodes) > 0 {
cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes))
for _, code := range rule.ErrorCodes {
cr.errorCodeSet[code] = struct{}{}
}
}
return cr
}
// =============================================================================
// 测试 ruleMatches 核心匹配逻辑
// 测试 ruleMatchesOptimized 核心匹配逻辑
// =============================================================================
func TestRuleMatches_NoConditions(t *testing.T) {
// 没有配置任何条件时,不应该匹配
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
})
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
var bodyLower string
var bodyLowerDone bool
assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone),
"没有配置条件时不应该匹配")
}
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
})
tests := []struct {
name string
@@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
assert.Equal(t, tt.expected, result)
})
}
@@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{"context limit", "model not supported"},
MatchMode: model.MatchModeAny,
}
})
tests := []struct {
name string
@@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
{"关键词匹配 context limit", 500, "error: context limit reached", true},
{"关键词匹配 model not supported", 400, "the model not supported here", true},
{"关键词不匹配", 422, "some other error", false},
// 注意ruleMatches 接收的 body 参数应该是已经转换为小写的
// 实际使用时MatchRule 会先将 body 转换为小写再传给 ruleMatches
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
{"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 MatchRule 的行为:先转换为小写
bodyLower := strings.ToLower(tt.body)
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
assert.Equal(t, tt.expected, result)
})
}
@@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
// any 模式:错误码 OR 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAny,
}
})
tests := []struct {
name string
@@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
@@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
// all 模式:错误码 AND 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll,
}
})
tests := []struct {
name string
@@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
var bodyLower string
var bodyLowerDone bool
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
// =============================================================================
// 测试 platformMatches 平台匹配逻辑
// 测试 platformMatchesCached 平台匹配逻辑
// =============================================================================
func TestPlatformMatches(t *testing.T) {
@@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rule := &model.ErrorPassthroughRule{
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
Platforms: tt.rulePlatforms,
}
result := svc.platformMatches(rule, tt.requestPlatform)
})
result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform))
assert.Equal(t, tt.expected, result)
})
}

View File

@@ -0,0 +1,472 @@
//go:build unit
package service
import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Mocks (scoped to this file by naming convention)
// ---------------------------------------------------------------------------
// epFixedUpstream returns a fixed response for every request.
type epFixedUpstream struct {
statusCode int
body string
calls int
}
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
u.calls++
return &http.Response{
StatusCode: u.statusCode,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(u.body)),
}, nil
}
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
// epAccountRepo records SetTempUnschedulable / SetError calls.
type epAccountRepo struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
}
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrCalls++
return nil
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
func saveAndSetBaseURLs(t *testing.T) {
t.Helper()
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvail := antigravity.DefaultURLAvailability
antigravity.BaseURLs = []string{"https://ep-test.example"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
t.Cleanup(func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvail
})
}
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
return antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[ep-test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: handleError,
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
tests := []struct {
name string
upstreamStatus int
upstreamBody string
customCodes []any
expectHandleError int
expectUpstream int
expectStatusCode int
}{
{
name: "429_in_custom_codes_matched",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(429)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 429,
},
{
name: "429_not_in_custom_codes_skipped",
upstreamStatus: 429,
upstreamBody: `{"error":"rate limited"}`,
customCodes: []any{float64(500)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 500,
},
{
name: "500_in_custom_codes_matched",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(500)},
expectHandleError: 1,
expectUpstream: 1,
expectStatusCode: 500,
},
{
name: "500_not_in_custom_codes_skipped",
upstreamStatus: 500,
upstreamBody: `{"error":"internal"}`,
customCodes: []any{float64(429)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 500,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": tt.customCodes,
},
}
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
})
}
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_TempUnschedulable
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
tempRulesAccount := func(rules []any) *Account {
return &Account{
ID: 200,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": rules,
},
}
}
overloadedRule := map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
}
rateLimitRule := map[string]any{
"error_code": float64(429),
"keywords": []any{"rate limited keyword"},
"duration_minutes": float64(5),
}
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{rateLimitRule})
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
t.Error("handleError should not be called for temp unschedulable")
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, 1, upstream.calls, "should not retry")
})
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := tempRulesAccount([]any{overloadedRule})
// Use a short-lived context: the backoff sleep (~1s) will be
// interrupted, proving the code entered the default retry path
// instead of breaking early via error policy.
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
result, err := svc.antigravityRetryLoop(p)
// Context cancellation during backoff proves default retry was entered
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
})
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NilRateLimitService
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
// rateLimitService is nil — must not panic
svc := &AntigravityGatewayService{rateLimitService: nil}
account := &Account{
ID: 300,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
return nil
})
p.ctx = ctx
// Should not panic; enters the default retry path (eventually times out)
result, err := svc.antigravityRetryLoop(p)
require.Nil(t, result)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.GreaterOrEqual(t, upstream.calls, 1)
}
// ---------------------------------------------------------------------------
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
// ---------------------------------------------------------------------------
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
repo := &epAccountRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
// Plain OAuth account with no error policy configured
account := &Account{
ID: 400,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.resp)
defer func() { _ = result.resp.Body.Close() }()
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
}
// ---------------------------------------------------------------------------
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
// ---------------------------------------------------------------------------
type epTrackingRepo struct {
mockAccountRepoForGemini
rateLimitedCalls int
rateLimitedID int64
setErrCalls int
setErrID int64
tempCalls int
}
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
r.rateLimitedCalls++
r.rateLimitedID = id
return nil
}
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
r.setErrCalls++
r.setErrID = id
return nil
}
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
// ---------------------------------------------------------------------------
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
//
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
// 当上游返回 429/500/503/401 时:
// - 返回给客户端的状态码必须是 500而不是透传原始状态码
// - 不调用 SetRateLimited不进入限流状态
// - 不调用 SetError不停止调度
// - 不调用 handleError
// ---------------------------------------------------------------------------
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
errorCodes := []int{429, 500, 503, 401, 403}
for _, upstreamStatus := range errorCodes {
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{
statusCode: upstreamStatus,
body: `{"error":"some upstream error"}`,
}
repo := &epTrackingRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := &Account{
ID: 500,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(599)},
},
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
// 不应返回 errorSkipped 不触发账号切换)
require.NoError(t, err, "should not return error")
require.NotNil(t, result, "result should not be nil")
require.NotNil(t, result.resp, "response should not be nil")
defer func() { _ = result.resp.Body.Close() }()
// 状态码必须是 500不透传原始状态码
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
"skipped error should return 500, not %d", upstreamStatus)
// 不调用 handleError
require.Equal(t, 0, handleErrorCount,
"handleError should NOT be called for skipped errors")
// 不标记限流
require.Equal(t, 0, repo.rateLimitedCalls,
"SetRateLimited should NOT be called for skipped errors")
// 不停止调度
require.Equal(t, 0, repo.setErrCalls,
"SetError should NOT be called for skipped errors")
// 只调用一次上游(不重试)
require.Equal(t, 1, upstream.calls,
"should call upstream exactly once (no retry)")
})
}
}

View File

@@ -0,0 +1,295 @@
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "no_policy_oauth_returns_none",
account: &Account{
ID: 1,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
// no custom error codes, no temp rules
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_hit_returns_matched",
account: &Account{
ID: 2,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expected: ErrorPolicyMatched,
},
{
name: "custom_error_codes_miss_returns_skipped",
account: &Account{
ID: 3,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 503,
body: []byte(`"error"`),
expected: ErrorPolicySkipped,
},
{
name: "temp_unschedulable_hit_returns_temp_unscheduled",
account: &Account{
ID: 4,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_body_miss_returns_none",
account: &Account{
ID: 5,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`random msg`),
expected: ErrorPolicyNone,
},
{
name: "custom_error_codes_override_temp_unschedulable",
account: &Account{
ID: 6,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
"description": "overloaded rule",
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
})
}
}
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------
func TestApplyErrorPolicy(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expectedHandled bool
expectedStatus int // expected outStatus
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
handleErrorCalls int
}{
{
name: "none_not_handled",
account: &Account{
ID: 10,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: false,
expectedStatus: 500, // passthrough
handleErrorCalls: 0,
},
{
name: "skipped_handled_no_handleError",
account: &Account{
ID: 11,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500, // not in custom codes
body: []byte(`"error"`),
expectedHandled: true,
expectedStatus: http.StatusInternalServerError, // skipped → 500
handleErrorCalls: 0,
},
{
name: "matched_handled_calls_handleError",
account: &Account{
ID: 12,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(500)},
},
},
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: true,
expectedStatus: 500, // matched → original status
handleErrorCalls: 1,
},
{
name: "temp_unscheduled_returns_switch_error",
account: &Account{
ID: 13,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expectedHandled: true,
expectedStatus: 503, // temp_unscheduled → original status
expectedSwitchErr: true,
handleErrorCalls: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{
rateLimitService: rlSvc,
}
var handleErrorCount int
p := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: tt.account,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
},
isStickySession: true,
}
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
if tt.expectedSwitchErr {
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, retErr, &switchErr)
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
} else {
require.NoError(t, retErr)
}
})
}
}
// ---------------------------------------------------------------------------
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
// ---------------------------------------------------------------------------
type errorPolicyRepoStub struct {
mockAccountRepoForGemini
tempCalls int
setErrCalls int
lastErrorMsg string
}
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrCalls++
r.lastErrorMsg = errorMsg
return nil
}

View File

@@ -77,6 +77,9 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error {
return nil
}
@@ -142,9 +145,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
@@ -216,22 +216,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context
return nil
}
func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
@@ -290,6 +274,10 @@ func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, g
return nil, nil
}
func (m *mockGroupRepoForGateway) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return nil
}
func ptr[T any](v T) *T {
return &v
}

View File

@@ -6,9 +6,19 @@ import (
"fmt"
"math"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
// 仅在 GenerateSessionHash 第 3 级 fallback消息内容 hash时混入
// 避免不同用户发送相同消息产生相同 hash 导致账号集中。
type SessionContext struct {
ClientIP string
UserAgent string
APIKeyID int64
}
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
@@ -22,20 +32,22 @@ import (
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking部分平台会影响最终模型名
MaxTokens int // max_tokens 值(用于探测请求拦截)
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
ThinkingEnabled bool // 是否开启 thinking部分平台会影响最终模型名
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选请求上下文区分因子nil 时行为不变)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// protocol 指定请求协议格式domain.PlatformAnthropic / domain.PlatformGemini
// 不同协议使用不同的 system/messages 字段名。
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
@@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
parsed.MetadataUserID = userID
}
}
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
switch protocol {
case domain.PlatformGemini:
// Gemini 原生格式: systemInstruction.parts / contents
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
parsed.System = parts
}
}
if contents, ok := req["contents"].([]any); ok {
parsed.Messages = contents
}
default:
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
}
// thinking: {type: "enabled"}

View File

@@ -4,12 +4,13 @@ import (
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/stretchr/testify/require"
)
func TestParseGatewayRequest(t *testing.T) {
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
require.True(t, parsed.Stream)
@@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) {
func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, "claude-sonnet-4-5", parsed.Model)
require.True(t, parsed.ThinkingEnabled)
@@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) {
func TestParseGatewayRequest_MaxTokens(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, 1, parsed.MaxTokens)
}
func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
require.Equal(t, 0, parsed.MaxTokens)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
require.NoError(t, err)
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require.True(t, parsed.HasSystem)
@@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) {
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
body := []byte(`{"model":123}`)
_, err := ParseGatewayRequest(body)
_, err := ParseGatewayRequest(body, "")
require.Error(t, err)
}
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
body := []byte(`{"stream":"true"}`)
_, err := ParseGatewayRequest(body)
_, err := ParseGatewayRequest(body, "")
require.Error(t, err)
}
// ============ Gemini 原生格式解析测试 ============
func TestParseGatewayRequest_GeminiContents(t *testing.T) {
body := []byte(`{
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Len(t, parsed.Messages, 3, "should parse contents as Messages")
require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem")
require.Nil(t, parsed.System, "no systemInstruction means nil System")
}
func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) {
body := []byte(`{
"systemInstruction": {
"parts": [{"text": "You are a helpful assistant."}]
},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]}
]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System")
parts, ok := parsed.System.([]any)
require.True(t, ok)
require.Len(t, parts, 1)
partMap, ok := parts[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "You are a helpful assistant.", partMap["text"])
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiWithModel(t *testing.T) {
body := []byte(`{
"model": "gemini-2.5-pro",
"contents": [{"role": "user", "parts": [{"text": "test"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Equal(t, "gemini-2.5-pro", parsed.Model)
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) {
// Gemini 格式下 system/messages 字段应被忽略
body := []byte(`{
"system": "should be ignored",
"messages": [{"role": "user", "content": "ignored"}],
"contents": [{"role": "user", "parts": [{"text": "real content"}]}]
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field")
require.Nil(t, parsed.System, "no systemInstruction = nil System")
require.Len(t, parsed.Messages, 1, "should use contents, not messages")
}
func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) {
body := []byte(`{"contents": []}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Empty(t, parsed.Messages)
}
func TestParseGatewayRequest_GeminiNoContents(t *testing.T) {
body := []byte(`{"model": "gemini-2.5-flash"}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformGemini)
require.NoError(t, err)
require.Nil(t, parsed.Messages)
require.Equal(t, "gemini-2.5-flash", parsed.Model)
}
func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) {
// Anthropic 格式下 contents/systemInstruction 字段应被忽略
body := []byte(`{
"system": "real system",
"messages": [{"role": "user", "content": "real content"}],
"contents": [{"role": "user", "parts": [{"text": "ignored"}]}],
"systemInstruction": {"parts": [{"text": "ignored"}]}
}`)
parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic)
require.NoError(t, err)
require.True(t, parsed.HasSystem)
require.Equal(t, "real system", parsed.System)
require.Len(t, parsed.Messages, 1)
msg, ok := parsed.Messages[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "real content", msg["content"])
}
func TestFilterThinkingBlocks(t *testing.T) {
containsThinkingBlock := func(body []byte) bool {
var req map[string]any

View File

@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -17,6 +16,7 @@ import (
"os"
"regexp"
"sort"
"strconv"
"strings"
"sync/atomic"
"time"
@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -242,12 +243,15 @@ var (
}
)
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
"accept": true,
@@ -273,13 +277,6 @@ var allowedHeaders = map[string]bool{
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话Sticky Session的存储、查询、刷新和删除功能。
//
// ModelLoadInfo 模型负载信息(用于 Antigravity 调度)
// Model load info for Antigravity scheduling
type ModelLoadInfo struct {
CallCount int64 // 当前分钟调用次数 / Call count in current minute
LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled)
}
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
@@ -295,24 +292,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// IncrModelCallCount 增加模型调用次数并更新最后调度时间Antigravity 专用)
// Increment model call count and update last scheduling time (Antigravity only)
// 返回更新后的调用次数
IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error)
// GetModelLoadBatch 批量获取账号的模型负载信息Antigravity 专用)
// Batch get model load info for accounts (Antigravity only)
GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error)
// FindGeminiSession 查找 Gemini 会话MGET 倒序匹配)
// Find Gemini session using MGET reverse order matching
// 返回最长匹配的会话信息uuid, accountID
FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool)
// SaveGeminiSession 保存 Gemini 会话
// Save Gemini session binding
SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -323,21 +302,15 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。
// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。
// 低于此阈值时保持粘性会话,等待短暂限流结束。
const stickySessionRateLimitThreshold = 10 * time.Second
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。
// 或请求的模型处于限流状态时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// within temporary unschedulable period, or model rate limit remaining time
// exceeds stickySessionRateLimitThreshold.
// within temporary unschedulable period, or the requested model is rate-limited.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
@@ -349,8 +322,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
// 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold {
// 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true
}
return false
@@ -395,15 +368,31 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应应在同一账号上重试 N 次再切换
}
func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
}
// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。
// 由 handler 层在同账号重试全部用尽、切换账号时调用。
func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) {
if failoverErr == nil || !failoverErr.RetryableOnSameAccount {
return
}
// 根据状态码选择封禁策略
switch failoverErr.StatusCode {
case http.StatusBadRequest:
tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]")
case http.StatusBadGateway:
tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]")
}
}
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
@@ -413,6 +402,7 @@ type GatewayService struct {
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
digestStore *DigestSessionStore
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
@@ -446,6 +436,7 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
digestStore *DigestSessionStore,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
@@ -455,6 +446,7 @@ func NewGatewayService(
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
digestStore: digestStore,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
@@ -488,23 +480,45 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return s.hashContent(cacheableContent)
}
// 3. Fallback: 使用 system 内容
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
var combined strings.Builder
// 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash
if parsed.SessionContext != nil {
_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(parsed.SessionContext.UserAgent)
_, _ = combined.WriteString(":")
_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
_, _ = combined.WriteString("|")
}
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
return s.hashContent(systemText)
_, _ = combined.WriteString(systemText)
}
}
// 4. 最后 fallback: 使用第一条消息
if len(parsed.Messages) > 0 {
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" {
return s.hashContent(msgText)
for _, msg := range parsed.Messages {
if m, ok := msg.(map[string]any); ok {
if content, exists := m["content"]; exists {
// Anthropic: messages[].content
if msgText := s.extractTextFromContent(content); msgText != "" {
_, _ = combined.WriteString(msgText)
}
} else if parts, ok := m["parts"].([]any); ok {
// Gemini: contents[].parts[].text
for _, part := range parts {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
_, _ = combined.WriteString(text)
}
}
}
}
}
}
if combined.Len() > 0 {
return s.hashContent(combined.String())
}
return ""
}
@@ -532,19 +546,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID
// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配)
// 返回最长匹配的会话信息uuid, accountID
func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
if digestChain == "" || s.cache == nil {
return "", 0, false
func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain)
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveGeminiSession 保存 Gemini 会话
func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
if digestChain == "" || s.cache == nil {
// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain用于删旧 key。
func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID)
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配)
func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) {
if digestChain == "" || s.digestStore == nil {
return "", 0, "", false
}
return s.digestStore.Find(groupID, prefixHash, digestChain)
}
// SaveAnthropicSession 保存 Anthropic 会话
func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error {
if digestChain == "" || s.digestStore == nil {
return nil
}
s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain)
return nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
@@ -629,8 +661,8 @@ func (s *GatewayService) extractTextFromContent(content any) string {
}
func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:16]) // 32字符
h := xxhash.Sum64String(content)
return strconv.FormatUint(h, 36)
}
// replaceModelInBody 替换请求体中的model字段
@@ -989,13 +1021,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
@@ -1110,7 +1135,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
@@ -1190,6 +1214,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(routingAvailable)
// 4. 尝试获取槽位
for _, item := range routingAvailable {
@@ -1264,7 +1289,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
@@ -1344,10 +1368,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return result, nil
}
} else {
// Antigravity 平台:获取模型负载信息
var modelLoadMap map[int64]*ModelLoadInfo
isAntigravity := platform == PlatformAntigravity
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
@@ -1362,109 +1382,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致)
if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 {
modelLoadMap = make(map[int64]*ModelLoadInfo, len(available))
modelToAccountIDs := make(map[string][]int64)
for _, item := range available {
mappedModel := mapAntigravityModel(item.account, requestedModel)
if mappedModel == "" {
continue
}
modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID)
// 分层过滤选择:优先级 → 负载率 → LRU
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
for model, ids := range modelToAccountIDs {
batch, err := s.cache.GetModelLoadBatch(ctx, ids, model)
if err != nil {
continue
}
for id, info := range batch {
modelLoadMap[id] = info
}
}
if len(modelLoadMap) == 0 {
modelLoadMap = nil
}
}
// Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值)
// 其他平台:分层过滤选择:优先级 → 负载率 → LRU
if isAntigravity {
for len(available) > 0 {
// 1. 取优先级最小的集合(硬过滤)
candidates := filterByMinPriority(available)
// 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值)
selected := selectByCallCount(candidates, modelLoadMap, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 移除已尝试的账号,重新选择
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
} else {
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
}
result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
} else {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
// 移除已尝试的账号,重新进行分层过滤
selectedID := selected.account.ID
newAvailable := make([]accountWithLoad, 0, len(available)-1)
for _, acc := range available {
if acc.account.ID != selectedID {
newAvailable = append(newAvailable, acc)
}
}
available = newAvailable
}
available = newAvailable
}
}
@@ -1750,6 +1705,17 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool {
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true)
if err != nil {
return false
}
return len(accounts) == 1
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
@@ -2000,87 +1966,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
shuffleWithinPriorityAndLastUsed(accounts)
}
// selectByCallCount 从候选账号中选择调用次数最少的账号Antigravity 专用)
// 新账号CallCount=0使用平均调用次数作为虚拟值避免冷启动被猛调
// 如果有多个账号具有相同的最小调用次数,则随机选择一个
func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad {
if len(accounts) == 0 {
return nil
// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。
// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。
func shuffleWithinSortGroups(accounts []accountWithLoad) {
if len(accounts) <= 1 {
return
}
if len(accounts) == 1 {
return &accounts[0]
}
// 如果没有负载信息,回退到 LRU
if modelLoadMap == nil {
return selectByLRU(accounts, preferOAuth)
}
// 1. 计算平均调用次数(用于新账号冷启动)
var totalCallCount int64
var countWithCalls int
for _, acc := range accounts {
if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 {
totalCallCount += info.CallCount
countWithCalls++
i := 0
for i < len(accounts) {
j := i + 1
for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) {
j++
}
}
var avgCallCount int64
if countWithCalls > 0 {
avgCallCount = totalCallCount / int64(countWithCalls)
}
// 2. 获取每个账号的有效调用次数
getEffectiveCallCount := func(acc accountWithLoad) int64 {
if acc.account == nil {
return 0
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
}
info := modelLoadMap[acc.account.ID]
if info == nil || info.CallCount == 0 {
return avgCallCount // 新账号使用平均值
}
return info.CallCount
i = j
}
}
// 3. 找到最小调用次数
minCount := getEffectiveCallCount(accounts[0])
for _, acc := range accounts[1:] {
if c := getEffectiveCallCount(acc); c < minCount {
minCount = c
}
// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组
func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
if a.account.Priority != b.account.Priority {
return false
}
// 4. 收集所有具有最小调用次数的账号
var candidateIdxs []int
for i, acc := range accounts {
if getEffectiveCallCount(acc) == minCount {
candidateIdxs = append(candidateIdxs, i)
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return false
}
return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt)
}
// 5. 如果只有一个候选,直接返回
if len(candidateIdxs) == 1 {
return &accounts[candidateIdxs[0]]
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
func shuffleWithinPriorityAndLastUsed(accounts []*Account) {
if len(accounts) <= 1 {
return
}
// 6. preferOAuth 处理
if preferOAuth {
var oauthIdxs []int
for _, idx := range candidateIdxs {
if accounts[idx].account.Type == AccountTypeOAuth {
oauthIdxs = append(oauthIdxs, idx)
}
i := 0
for i < len(accounts) {
j := i + 1
for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) {
j++
}
if len(oauthIdxs) > 0 {
candidateIdxs = oauthIdxs
if j-i > 1 {
mathrand.Shuffle(j-i, func(a, b int) {
accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a]
})
}
i = j
}
}
// 7. 随机选择
return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]]
// sameAccountGroup 判断两个 Account 是否属于同一排序组Priority + LastUsedAt
func sameAccountGroup(a, b *Account) bool {
if a.Priority != b.Priority {
return false
}
return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt)
}
// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒)
func sameLastUsedAt(a, b *time.Time) bool {
switch {
case a == nil && b == nil:
return true
case a == nil || b == nil:
return false
default:
return a.Unix() == b.Unix()
}
}
// sortCandidatesForFallback 根据配置选择排序策略
@@ -2135,13 +2093,6 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
@@ -2169,9 +2120,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2272,9 +2220,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
@@ -2383,9 +2328,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2488,9 +2430,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
@@ -2767,6 +2706,60 @@ func hasClaudeCodePrefix(text string) bool {
return false
}
// matchesFilterPrefix 检查文本是否匹配任一过滤前缀
func matchesFilterPrefix(text string) bool {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return true
}
}
return false
}
// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素
// 直接从 body 解析 system不依赖外部传入的 parsed.System因为前置步骤可能已修改 body 中的 system
func filterSystemBlocksByPrefix(body []byte) []byte {
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return body
}
switch {
case sys.Type == gjson.String:
if matchesFilterPrefix(sys.Str) {
result, err := sjson.DeleteBytes(body, "system")
if err != nil {
return body
}
return result
}
case sys.IsArray():
var parsed []any
if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil {
return body
}
filtered := make([]any, 0, len(parsed))
changed := false
for _, item := range parsed {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) {
changed = true
continue
}
}
filtered = append(filtered, item)
}
if changed {
result, err := sjson.SetBytes(body, "system", filtered)
if err != nil {
return body
}
return result
}
}
return body
}
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte {
@@ -3046,6 +3039,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据)
// 放在 inject/normalize 之后,确保不会被覆盖
if account.IsOAuth() {
body = filterSystemBlocksByPrefix(body)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
@@ -5165,27 +5164,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return normalized, nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
if !ok {
return nil // 无法解析 scope跳过检查
}
group, err := s.resolveGroupByID(ctx, groupID)
if err != nil {
return nil // 查询失败时放行
}
if group == nil {
return nil // 分组不存在时放行
}
if !IsScopeSupported(group.SupportedModelScopes, scope) {
return ErrModelScopeNotSupported
}
return nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {

View File

@@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
parsed, err := ParseGatewayRequest(body)
parsed, err := ParseGatewayRequest(body, "")
if err != nil {
b.Fatalf("解析请求失败: %v", err)
}

View File

@@ -0,0 +1,384 @@
//go:build unit
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
// for the ErrorPolicyNone path (original logic preserved).
// ---------------------------------------------------------------------------
func TestShouldFailoverGeminiUpstreamError(t *testing.T) {
svc := &GeminiMessagesCompatService{}
tests := []struct {
name string
statusCode int
expected bool
}{
{"401_failover", 401, true},
{"403_failover", 403, true},
{"429_failover", 429, true},
{"529_failover", 529, true},
{"500_failover", 500, true},
{"502_failover", 502, true},
{"503_failover", 503, true},
{"400_no_failover", 400, false},
{"404_no_failover", 404, false},
{"422_no_failover", 422, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
// correctly for Gemini platform accounts (API Key type).
// ---------------------------------------------------------------------------
func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
tests := []struct {
name string
account *Account
statusCode int
body []byte
expected ErrorPolicyResult
}{
{
name: "gemini_apikey_custom_codes_hit",
account: &Account{
ID: 100,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429), float64(500)},
},
},
statusCode: 429,
body: []byte(`{"error":"rate limited"}`),
expected: ErrorPolicyMatched,
},
{
name: "gemini_apikey_custom_codes_miss",
account: &Account{
ID: 101,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicySkipped,
},
{
name: "gemini_apikey_no_custom_codes_returns_none",
account: &Account{
ID: 102,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 500,
body: []byte(`{"error":"internal"}`),
expected: ErrorPolicyNone,
},
{
name: "gemini_apikey_temp_unschedulable_hit",
account: &Account{
ID: 103,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded service`),
expected: ErrorPolicyTempUnscheduled,
},
{
name: "gemini_custom_codes_override_temp_unschedulable",
account: &Account{
ID: 104,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(503)},
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
body: []byte(`overloaded`),
expected: ErrorPolicyMatched, // custom codes take precedence
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &errorPolicyRepoStub{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
require.Equal(t, tt.expected, result)
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
// paths produce the correct behavior for each ErrorPolicyResult.
//
// These tests simulate the inline error policy switch in handleClaudeCompat
// and forwardNativeGemini by calling the same methods in the same order.
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicyIntegration(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
account *Account
statusCode int
respBody []byte
expectFailover bool // expect UpstreamFailoverError
expectHandleError bool // expect handleGeminiUpstreamError to be called
expectShouldFailover bool // for None path, whether shouldFailover triggers
}{
{
name: "custom_codes_matched_429_failover",
account: &Account{
ID: 200,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
},
{
name: "custom_codes_skipped_500_no_failover",
account: &Account{
ID: 201,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
},
statusCode: 500,
respBody: []byte(`{"error":"internal"}`),
expectFailover: false,
expectHandleError: false,
},
{
name: "temp_unschedulable_matched_failover",
account: &Account{
ID: 202,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(503),
"keywords": []any{"overloaded"},
"duration_minutes": float64(10),
},
},
},
},
statusCode: 503,
respBody: []byte(`overloaded`),
expectFailover: true,
expectHandleError: true,
},
{
name: "no_policy_429_failover_via_shouldFailover",
account: &Account{
ID: 203,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 429,
respBody: []byte(`{"error":"rate limited"}`),
expectFailover: true,
expectHandleError: true,
expectShouldFailover: true,
},
{
name: "no_policy_400_no_failover",
account: &Account{
ID: 204,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
statusCode: 400,
respBody: []byte(`{"error":"bad request"}`),
expectFailover: false,
expectHandleError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &geminiErrorPolicyRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &GeminiMessagesCompatService{
accountRepo: repo,
rateLimitService: rlSvc,
}
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// Simulate the Claude compat error handling path (same logic as native).
// This mirrors the inline switch in handleClaudeCompat.
var handleErrorCalled bool
var gotFailover bool
ctx := context.Background()
statusCode := tt.statusCode
respBody := tt.respBody
account := tt.account
headers := http.Header{}
if svc.rateLimitService != nil {
switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) {
case ErrorPolicySkipped:
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
gotFailover = false
handleErrorCalled = false
goto verify
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
gotFailover = true
goto verify
}
}
// ErrorPolicyNone → original logic
svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody)
handleErrorCalled = true
if svc.shouldFailoverGeminiUpstreamError(statusCode) {
gotFailover = true
}
verify:
require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch")
require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch")
if tt.expectShouldFailover {
require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode),
"shouldFailoverGeminiUpstreamError should return true for status %d", statusCode)
}
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
// ---------------------------------------------------------------------------
func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
svc := &GeminiMessagesCompatService{
rateLimitService: nil,
}
// When rateLimitService is nil, error policy is skipped → falls through to
// shouldFailoverGeminiUpstreamError (original logic).
// Verify this doesn't panic and follows expected behavior.
ctx := context.Background()
account := &Account{
ID: 300,
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(429)},
},
}
// The nil check should prevent CheckErrorPolicy from being called
if svc.rateLimitService != nil {
t.Fatal("rateLimitService should be nil for this test")
}
// shouldFailoverGeminiUpstreamError still works
require.True(t, svc.shouldFailoverGeminiUpstreamError(429))
require.False(t, svc.shouldFailoverGeminiUpstreamError(400))
// handleGeminiUpstreamError should not panic with nil rateLimitService
require.NotPanics(t, func() {
svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`))
})
}
// ---------------------------------------------------------------------------
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
// ---------------------------------------------------------------------------
type geminiErrorPolicyRepo struct {
mockAccountRepoForGemini
setErrorCalls int
setRateLimitedCalls int
setTempCalls int
}
func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error {
r.setErrorCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
r.setRateLimitedCalls++
return nil
}
func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.setTempCalls++
return nil
}

View File

@@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
@@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
@@ -776,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
break
}
// 错误策略优先:匹配则跳过重试直接处理。
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
@@ -837,37 +839,77 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
// 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
case ErrorPolicySkipped:
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamDetail = truncateString(string(respBody), maxBytes)
return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
}
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
if resp.StatusCode == http.StatusBadRequest {
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if isGoogleProjectConfigError(msg400) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
@@ -1026,10 +1068,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, "", errors.New("gemini api_key not configured")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
@@ -1097,10 +1136,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return upstreamReq, "x-request-id", nil
} else {
// Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token)
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, "", err
@@ -1179,6 +1215,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
}
// 错误策略优先:匹配则跳过重试直接处理。
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
@@ -1261,14 +1305,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
// Checked before error policy so it always works regardless of custom error codes.
if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
@@ -1282,29 +1321,73 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
if tempMatched {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
// 统一错误策略:自定义错误码 + 临时不可调度
if s.rateLimitService != nil {
switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) {
case ErrorPolicySkipped:
respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
upstreamDetail = truncateString(string(evBody), maxBytes)
c.Data(http.StatusInternalServerError, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
}
// ErrorPolicyNone → 原有逻辑
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
if resp.StatusCode == http.StatusBadRequest {
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if isGoogleProjectConfigError(msg400) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody)))
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true}
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
@@ -1417,6 +1500,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
// checkErrorPolicyInLoop 在重试循环内预检查错误策略。
// 返回 true 表示策略已匹配(调用者应 breakresp 已重建可直接使用。
// 返回 false 表示 ErrorPolicyNoneresp 已重建,调用者继续走重试逻辑。
func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop(
ctx context.Context, account *Account, resp *http.Response,
) (matched bool, rebuilt *http.Response) {
if resp.StatusCode < 400 || s.rateLimitService == nil {
return false, resp
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
rebuilt = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(body)),
}
policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body)
return policy != ErrorPolicyNone, rebuilt
}
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
@@ -2420,10 +2523,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, errors.New("invalid path")
}
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
if baseURL == "" {
baseURL = geminicli.AIStudioBaseURL
}
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
@@ -2592,6 +2692,10 @@ func asInt(v any) (int, bool) {
}
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
// 遵守自定义错误码策略:未命中则跳过所有限流处理
if !account.ShouldHandleErrorCode(statusCode) {
return
}
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return

View File

@@ -66,6 +66,9 @@ func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account)
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
@@ -133,9 +136,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
@@ -226,6 +226,10 @@ func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, gr
return nil, nil
}
func (m *mockGroupRepoForGemini) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return nil
}
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
@@ -265,22 +269,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context,
return nil
}
func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()

View File

@@ -6,26 +6,11 @@ import (
"encoding/json"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/cespare/xxhash/v2"
)
// Gemini 会话 ID Fallback 相关常量
const (
// geminiSessionTTLSeconds Gemini 会话缓存 TTL5 分钟)
geminiSessionTTLSeconds = 300
// geminiSessionKeyPrefix Gemini 会话 Redis key 前缀
geminiSessionKeyPrefix = "gemini:sess:"
)
// GeminiSessionTTL 返回 Gemini 会话缓存 TTL
func GeminiSessionTTL() time.Duration {
return geminiSessionTTLSeconds * time.Second
}
// shortHash 使用 XXHash64 + Base36 生成短 hash16 字符)
// XXHash64 比 SHA256 快约 10 倍Base36 比 Hex 短约 20%
func shortHash(data []byte) string {
@@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m
return base64.RawURLEncoding.EncodeToString(hash[:12])
}
// BuildGeminiSessionKey 构建 Gemini 会话 Redis key
// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain}
func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string {
return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain
}
// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短)
// 用于 MGET 批量查询最长匹配
func GenerateDigestChainPrefixes(chain string) []string {
if chain == "" {
return nil
}
var prefixes []string
c := chain
for c != "" {
prefixes = append(prefixes, c)
// 找到最后一个 "-" 的位置
if i := strings.LastIndex(c, "-"); i > 0 {
c = c[:i]
} else {
break
}
}
return prefixes
}
// ParseGeminiSessionValue 解析 Gemini 会话缓存值
// 格式: {uuid}:{accountID}
func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) {
@@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string {
// geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀
const geminiDigestSessionKeyPrefix = "gemini:digest:"
// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀
const geminiTrieKeyPrefix = "gemini:trie:"
// BuildGeminiTrieKey 构建 Gemini Trie Redis key
// 格式: gemini:trie:{groupID}:{prefixHash}
func BuildGeminiTrieKey(groupID int64, prefixHash string) string {
return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash
}
// GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
// 用于在 SelectAccountWithLoadAwareness 中保持粘性会话

View File

@@ -1,41 +1,14 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// mockGeminiSessionCache 模拟 Redis 缓存
type mockGeminiSessionCache struct {
sessions map[string]string // key -> value
}
func newMockGeminiSessionCache() *mockGeminiSessionCache {
return &mockGeminiSessionCache{sessions: make(map[string]string)}
}
func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) {
key := BuildGeminiSessionKey(groupID, prefixHash, digestChain)
value := FormatGeminiSessionValue(uuid, accountID)
m.sessions[key] = value
}
func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
prefixes := GenerateDigestChainPrefixes(digestChain)
for _, p := range prefixes {
key := BuildGeminiSessionKey(groupID, prefixHash, p)
if val, ok := m.sessions[key]; ok {
return ParseGeminiSessionValue(val)
}
}
return "", 0, false
}
// TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配
func TestGeminiSessionContinuousConversation(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
sessionUUID := "session-uuid-12345"
@@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 1 chain: %s", chain1)
// 第一轮:没有找到会话,创建新会话
_, _, found := cache.Find(groupID, prefixHash, chain1)
_, _, _, found := store.Find(groupID, prefixHash, chain1)
if found {
t.Error("Round 1: should not find existing session")
}
// 保存第一轮会话
cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID)
// 保存第一轮会话(首轮无旧 chain
store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "")
// 模拟第二轮对话(用户继续对话)
req2 := &antigravity.GeminiRequest{
@@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 2 chain: %s", chain2)
// 第二轮:应该能找到会话(通过前缀匹配)
foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2)
foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2)
if !found {
t.Error("Round 2: should find session via prefix matching")
}
@@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID)
}
// 保存第二轮会话
cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID)
// 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key
store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain)
// 模拟第三轮对话
req3 := &antigravity.GeminiRequest{
@@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
t.Logf("Round 3 chain: %s", chain3)
// 第三轮:应该能找到会话(通过第二轮的前缀匹配)
foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3)
foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3)
if !found {
t.Error("Round 3: should find session via prefix matching")
}
@@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) {
if foundAccID != accountID {
t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID)
}
t.Log("✓ Continuous conversation session matching works correctly!")
}
// TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配
func TestGeminiSessionDifferentConversations(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
@@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
},
}
chain1 := BuildGeminiDigestChain(req1)
cache.Save(groupID, prefixHash, chain1, "session-1", 100)
store.Save(groupID, prefixHash, chain1, "session-1", 100, "")
// 第二个完全不同的会话
req2 := &antigravity.GeminiRequest{
@@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) {
chain2 := BuildGeminiDigestChain(req2)
// 不同会话不应该匹配
_, _, found := cache.Find(groupID, prefixHash, chain2)
_, _, _, found := store.Find(groupID, prefixHash, chain2)
if found {
t.Error("Different conversations should not match")
}
t.Log("✓ Different conversations are correctly isolated!")
}
// TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先)
func TestGeminiSessionPrefixMatchingOrder(t *testing.T) {
cache := newMockGeminiSessionCache()
store := NewDigestSessionStore()
groupID := int64(1)
prefixHash := "test_prefix_hash"
// 创建一个三轮对话
req := &antigravity.GeminiRequest{
SystemInstruction: &antigravity.GeminiContent{
Parts: []antigravity.GeminiPart{{Text: "System prompt"}},
},
Contents: []antigravity.GeminiContent{
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}},
{Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}},
{Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}},
},
}
fullChain := BuildGeminiDigestChain(req)
prefixes := GenerateDigestChainPrefixes(fullChain)
t.Logf("Full chain: %s", fullChain)
t.Logf("Prefixes (longest first): %v", prefixes)
// 验证前缀生成顺序(从长到短)
if len(prefixes) != 4 {
t.Errorf("Expected 4 prefixes, got %d", len(prefixes))
}
// 保存不同轮次的会话到不同账号
// 第一轮(最短前缀)-> 账号 1
cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1)
// 第二轮 -> 账号 2
cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2)
// 第三轮(最长前缀,完整链)-> 账号 3
cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3)
store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "")
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "")
store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "")
// 查找应该返回最长匹配(账号 3
_, accID, found := cache.Find(groupID, prefixHash, fullChain)
// 查找更长的链,应该返回最长匹配(账号 3
_, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2")
if !found {
t.Error("Should find session")
}
if accID != 3 {
t.Errorf("Should match longest prefix (account 3), got account %d", accID)
}
t.Log("✓ Longest prefix matching works correctly!")
}
// 确保 context 包被使用(避免未使用的导入警告)
var _ = context.Background

View File

@@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) {
}
}
func TestGenerateDigestChainPrefixes(t *testing.T) {
tests := []struct {
name string
chain string
want []string
wantLen int
}{
{
name: "empty",
chain: "",
wantLen: 0,
},
{
name: "single part",
chain: "u:abc123",
want: []string{"u:abc123"},
wantLen: 1,
},
{
name: "two parts",
chain: "s:xyz-u:abc",
want: []string{"s:xyz-u:abc", "s:xyz"},
wantLen: 2,
},
{
name: "four parts",
chain: "s:a-u:b-m:c-u:d",
want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"},
wantLen: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateDigestChainPrefixes(tt.chain)
if len(result) != tt.wantLen {
t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result)
}
if tt.want != nil {
for i, want := range tt.want {
if i >= len(result) {
t.Errorf("missing prefix at index %d", i)
continue
}
if result[i] != want {
t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i])
}
}
}
})
}
}
func TestParseGeminiSessionValue(t *testing.T) {
tests := []struct {
name string
@@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) {
}
})
}
func TestBuildGeminiTrieKey(t *testing.T) {
tests := []struct {
name string
groupID int64
prefixHash string
want string
}{
{
name: "normal",
groupID: 123,
prefixHash: "abcdef12",
want: "gemini:trie:123:abcdef12",
},
{
name: "zero group",
groupID: 0,
prefixHash: "xyz",
want: "gemini:trie:0:xyz",
},
{
name: "empty prefix",
groupID: 1,
prefixHash: "",
want: "gemini:trie:1:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash)
if got != tt.want {
t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -45,6 +45,9 @@ type Group struct {
// 可选值: claude, gemini_text, gemini_image
SupportedModelScopes []string
// 分组排序
SortOrder int
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -33,6 +33,14 @@ type GroupRepository interface {
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
// BindAccountsToGroup 将多个账号绑定到指定分组
BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error
// UpdateSortOrders 批量更新分组排序
UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
}
// GroupSortOrderUpdate 分组排序更新
type GroupSortOrderUpdate struct {
ID int64 `json:"id"`
SortOrder int `json:"sort_order"`
}
// CreateGroupRequest 创建分组请求

View File

@@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) {
}
}
func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future10m := now.Add(10 * time.Minute).Format(time.RFC3339)
past := now.Add(-10 * time.Minute).Format(time.RFC3339)
tests := []struct {
name string
account *Account
requestedModel string
minExpected time.Duration
maxExpected time.Duration
}{
{
name: "nil account",
account: nil,
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "non-antigravity platform",
account: &Account{
Platform: PlatformAnthropic,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "claude scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "gemini_text scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"gemini_text": map[string]any{
"rate_limit_reset_at": future10m,
},
},
},
},
requestedModel: "gemini-3-flash",
minExpected: 9 * time.Minute,
maxExpected: 11 * time.Minute,
},
{
name: "expired scope rate limit",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": past,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 0,
maxExpected: 0,
},
{
name: "unsupported model",
account: &Account{
Platform: PlatformAntigravity,
},
requestedModel: "gpt-4",
minExpected: 0,
maxExpected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel)
if result < tt.minExpected || result > tt.maxExpected {
t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected)
}
})
}
}
func TestGetRateLimitRemainingTime(t *testing.T) {
now := time.Now()
future15m := now.Add(15 * time.Minute).Format(time.RFC3339)
@@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
maxExpected: 0,
},
{
name: "model remaining > scope remaining - returns model",
name: "model rate limited - 15 minutes",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
"rate_limit_reset_at": future15m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
maxExpected: 16 * time.Minute,
},
{
name: "scope remaining > model remaining - returns scope",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": future5m, // 5 分钟
},
},
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future15m, // 15 分钟
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 14 * time.Minute, // 应返回较大的 15 分钟
minExpected: 14 * time.Minute,
maxExpected: 16 * time.Minute,
},
{
@@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) {
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "only scope rate limited",
account: &Account{
Platform: PlatformAntigravity,
Extra: map[string]any{
antigravityQuotaScopesKey: map[string]any{
"claude": map[string]any{
"rate_limit_reset_at": future5m,
},
},
},
},
requestedModel: "claude-sonnet-4-5",
minExpected: 4 * time.Minute,
maxExpected: 6 * time.Minute,
},
{
name: "neither rate limited",
account: &Account{

View File

@@ -580,10 +580,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
@@ -618,6 +614,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(available)
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)

View File

@@ -204,22 +204,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i
return nil
}
func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
return 0, nil
}
func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) {
return nil, nil
}
func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
return "", 0, false
}
func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
return nil
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)

View File

@@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
}
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
scopeRateLimits := acc.GetAntigravityScopeRateLimits()
if acc.Platform != "" {
if _, ok := platform[acc.Platform]; !ok {
@@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
p.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if p.ScopeRateLimitCount == nil {
p.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
p.ScopeRateLimitCount[scope]++
}
}
}
for _, grp := range acc.Groups {
@@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
if hasError {
g.ErrorCount++
}
if len(scopeRateLimits) > 0 {
if g.ScopeRateLimitCount == nil {
g.ScopeRateLimitCount = make(map[string]int64)
}
for scope := range scopeRateLimits {
g.ScopeRateLimitCount[scope]++
}
}
}
displayGroupID := int64(0)
@@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi
item.RateLimitRemainingSec = &remainingSec
}
}
if len(scopeRateLimits) > 0 {
item.ScopeRateLimits = scopeRateLimits
}
if isOverloaded && acc.OverloadUntil != nil {
item.OverloadUntil = acc.OverloadUntil
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())

View File

@@ -50,24 +50,22 @@ type UserConcurrencyInfo struct {
// PlatformAvailability aggregates account availability by platform.
type PlatformAvailability struct {
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"`
}
// GroupAvailability aggregates account availability by group.
type GroupAvailability struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"`
ErrorCount int64 `json:"error_count"`
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
TotalAccounts int64 `json:"total_accounts"`
AvailableCount int64 `json:"available_count"`
RateLimitCount int64 `json:"rate_limit_count"`
ErrorCount int64 `json:"error_count"`
}
// AccountAvailability represents current availability for a single account.
@@ -85,11 +83,10 @@ type AccountAvailability struct {
IsOverloaded bool `json:"is_overloaded"`
HasError bool `json:"has_error"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"`
OverloadUntil *time.Time `json:"overload_until"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
OverloadUntil *time.Time `json:"overload_until"`
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
ErrorMessage string `json:"error_message"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
}

View File

@@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
@@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
switch reqType {
case opsRetryTypeMessages:
parsed, parseErr := ParseGatewayRequest(body)
parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
if parseErr != nil {
return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
}
@@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq
if s.gatewayService == nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
}
parsedReq, parseErr := ParseGatewayRequest(body)
parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
if parseErr != nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
}

View File

@@ -20,6 +20,10 @@ const (
// retry the specific upstream attempt (not just the client request).
// This value is sanitized+trimmed before being persisted.
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
// ops_error_logger 中间件检查此 key为 true 时跳过错误记录。
OpsSkipPassthroughKey = "ops_skip_passthrough"
)
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
@@ -103,6 +107,37 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
evCopy := ev
existing = append(existing, &evCopy)
c.Set(OpsUpstreamErrorsKey, existing)
checkSkipMonitoringForUpstreamEvent(c, &evCopy)
}
// checkSkipMonitoringForUpstreamEvent checks whether the upstream error event
// matches a passthrough rule with skip_monitoring=true and, if so, sets the
// OpsSkipPassthroughKey on the context. This ensures intermediate retry /
// failover errors (which never go through the final applyErrorPassthroughRule
// path) can still suppress ops_error_logs recording.
func checkSkipMonitoringForUpstreamEvent(c *gin.Context, ev *OpsUpstreamErrorEvent) {
if ev.UpstreamStatusCode == 0 {
return
}
svc := getBoundErrorPassthroughService(c)
if svc == nil {
return
}
// Use the best available body representation for keyword matching.
// Even when body is empty, MatchRule can still match rules that only
// specify ErrorCodes (no Keywords), so we always call it.
body := ev.Detail
if body == "" {
body = ev.Message
}
rule := svc.MatchRule(ev.Platform, ev.UpstreamStatusCode, []byte(body))
if rule != nil && rule.SkipMonitoring {
c.Set(OpsSkipPassthroughKey, true)
}
}
func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {

View File

@@ -62,6 +62,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
s.tokenCacheInvalidator = invalidator
}
// ErrorPolicyResult 表示错误策略检查的结果
type ErrorPolicyResult int
const (
ErrorPolicyNone ErrorPolicyResult = iota // 未命中任何策略,继续默认逻辑
ErrorPolicySkipped // 自定义错误码开启但未命中,跳过处理
ErrorPolicyMatched // 自定义错误码命中,应停止调度
ErrorPolicyTempUnscheduled // 临时不可调度规则命中
)
// CheckErrorPolicy 检查自定义错误码和临时不可调度规则。
// 自定义错误码开启时覆盖后续所有逻辑(包括临时不可调度)。
func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult {
if account.IsCustomErrorCodesEnabled() {
if account.ShouldHandleErrorCode(statusCode) {
return ErrorPolicyMatched
}
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return ErrorPolicySkipped
}
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
return ErrorPolicyTempUnscheduled
}
return ErrorPolicyNone
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
@@ -597,6 +623,10 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
// 同时清除模型级别限流
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
}
return nil
}

View File

@@ -0,0 +1,318 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ============ shuffleWithinSortGroups 测试 ============
func TestShuffleWithinSortGroups_Empty(t *testing.T) {
shuffleWithinSortGroups(nil)
shuffleWithinSortGroups([]accountWithLoad{})
}
func TestShuffleWithinSortGroups_SingleElement(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
shuffleWithinSortGroups(accounts)
require.Equal(t, int64(1), accounts[0].account.ID)
}
func TestShuffleWithinSortGroups_DifferentGroups_OrderPreserved(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 3, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
// 每个元素都属于不同组Priority 或 LoadRate 或 LastUsedAt 不同),顺序不变
for i := 0; i < 20; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
require.Equal(t, int64(1), cpy[0].account.ID)
require.Equal(t, int64(2), cpy[1].account.ID)
require.Equal(t, int64(3), cpy[2].account.ID)
}
}
func TestShuffleWithinSortGroups_SameGroup_Shuffled(t *testing.T) {
now := time.Now()
// 同一秒的时间戳视为同一组
sameSecond := time.Unix(now.Unix(), 0)
sameSecond2 := time.Unix(now.Unix(), 500_000_000) // 同一秒但不同纳秒
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &sameSecond2}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
// 多次执行,验证所有 ID 都出现在第一个位置(说明确实被打乱了)
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
seen[cpy[0].account.ID] = true
// 无论怎么打乱,所有 ID 都应在候选中
ids := map[int64]bool{}
for _, a := range cpy {
ids[a.account.ID] = true
}
require.True(t, ids[1] && ids[2] && ids[3])
}
// 至少 2 个不同的 ID 出现在首位(随机性验证)
require.GreaterOrEqual(t, len(seen), 2, "shuffle should produce different orderings")
}
func TestShuffleWithinSortGroups_NilLastUsedAt_SameGroup(t *testing.T) {
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
seen[cpy[0].account.ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "nil LastUsedAt accounts should be shuffled")
}
func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
sameAsNow := time.Unix(now.Unix(), 0)
// 组1: Priority=1, LoadRate=10, LastUsedAt=earlier (ID 1) — 单元素组
// 组2: Priority=1, LoadRate=20, LastUsedAt=now (ID 2, 3) — 双元素组
// 组3: Priority=2, LoadRate=10, LastUsedAt=earlier (ID 4) — 单元素组
accounts := []accountWithLoad{
{account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
{account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameAsNow}, loadInfo: &AccountLoadInfo{LoadRate: 20}},
{account: &Account{ID: 4, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}},
}
for i := 0; i < 20; i++ {
cpy := make([]accountWithLoad, len(accounts))
copy(cpy, accounts)
shuffleWithinSortGroups(cpy)
// 组间顺序不变
require.Equal(t, int64(1), cpy[0].account.ID, "group 1 position fixed")
require.Equal(t, int64(4), cpy[3].account.ID, "group 3 position fixed")
// 组2 内部可以打乱,但仍在位置 1 和 2
mid := map[int64]bool{cpy[1].account.ID: true, cpy[2].account.ID: true}
require.True(t, mid[2] && mid[3], "group 2 elements should stay in positions 1-2")
}
}
// ============ shuffleWithinPriorityAndLastUsed 测试 ============
func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) {
shuffleWithinPriorityAndLastUsed(nil)
shuffleWithinPriorityAndLastUsed([]*Account{})
}
func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) {
accounts := []*Account{{ID: 1, Priority: 1}}
shuffleWithinPriorityAndLastUsed(accounts)
require.Equal(t, int64(1), accounts[0].ID)
}
func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: nil},
{ID: 3, Priority: 1, LastUsedAt: nil},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
seen[cpy[0].ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled")
}
func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 2, LastUsedAt: nil},
{ID: 3, Priority: 3, LastUsedAt: nil},
}
for i := 0; i < 20; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
require.Equal(t, int64(1), cpy[0].ID)
require.Equal(t, int64(2), cpy[1].ID)
require.Equal(t, int64(3), cpy[2].ID)
}
}
func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: &earlier},
{ID: 3, Priority: 1, LastUsedAt: &now},
}
for i := 0; i < 20; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
shuffleWithinPriorityAndLastUsed(cpy)
require.Equal(t, int64(1), cpy[0].ID)
require.Equal(t, int64(2), cpy[1].ID)
require.Equal(t, int64(3), cpy[2].ID)
}
}
// ============ sameLastUsedAt 测试 ============
func TestSameLastUsedAt(t *testing.T) {
now := time.Now()
sameSecond := time.Unix(now.Unix(), 0)
sameSecondDiffNano := time.Unix(now.Unix(), 999_999_999)
differentSecond := now.Add(1 * time.Second)
t.Run("both nil", func(t *testing.T) {
require.True(t, sameLastUsedAt(nil, nil))
})
t.Run("one nil one not", func(t *testing.T) {
require.False(t, sameLastUsedAt(nil, &now))
require.False(t, sameLastUsedAt(&now, nil))
})
t.Run("same second different nanoseconds", func(t *testing.T) {
require.True(t, sameLastUsedAt(&sameSecond, &sameSecondDiffNano))
})
t.Run("different seconds", func(t *testing.T) {
require.False(t, sameLastUsedAt(&now, &differentSecond))
})
t.Run("exact same time", func(t *testing.T) {
require.True(t, sameLastUsedAt(&now, &now))
})
}
// ============ sameAccountWithLoadGroup 测试 ============
func TestSameAccountWithLoadGroup(t *testing.T) {
now := time.Now()
sameSecond := time.Unix(now.Unix(), 0)
t.Run("same group", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.True(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different priority", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 2, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different load rate", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("different last used at", func(t *testing.T) {
later := now.Add(1 * time.Second)
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &later}, loadInfo: &AccountLoadInfo{LoadRate: 10}}
require.False(t, sameAccountWithLoadGroup(a, b))
})
t.Run("both nil LastUsedAt", func(t *testing.T) {
a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}
require.True(t, sameAccountWithLoadGroup(a, b))
})
}
// ============ sameAccountGroup 测试 ============
func TestSameAccountGroup(t *testing.T) {
now := time.Now()
t.Run("same group", func(t *testing.T) {
a := &Account{Priority: 1, LastUsedAt: nil}
b := &Account{Priority: 1, LastUsedAt: nil}
require.True(t, sameAccountGroup(a, b))
})
t.Run("different priority", func(t *testing.T) {
a := &Account{Priority: 1, LastUsedAt: nil}
b := &Account{Priority: 2, LastUsedAt: nil}
require.False(t, sameAccountGroup(a, b))
})
t.Run("different LastUsedAt", func(t *testing.T) {
later := now.Add(1 * time.Second)
a := &Account{Priority: 1, LastUsedAt: &now}
b := &Account{Priority: 1, LastUsedAt: &later}
require.False(t, sameAccountGroup(a, b))
})
}
// ============ sortAccountsByPriorityAndLastUsed 集成随机化测试 ============
func TestSortAccountsByPriorityAndLastUsed_WithShuffle(t *testing.T) {
t.Run("same priority and nil LastUsedAt are shuffled", func(t *testing.T) {
accounts := []*Account{
{ID: 1, Priority: 1, LastUsedAt: nil},
{ID: 2, Priority: 1, LastUsedAt: nil},
{ID: 3, Priority: 1, LastUsedAt: nil},
}
seen := map[int64]bool{}
for i := 0; i < 100; i++ {
cpy := make([]*Account, len(accounts))
copy(cpy, accounts)
sortAccountsByPriorityAndLastUsed(cpy, false)
seen[cpy[0].ID] = true
}
require.GreaterOrEqual(t, len(seen), 2, "identical sort keys should produce different orderings after shuffle")
})
t.Run("different priorities still sorted correctly", func(t *testing.T) {
now := time.Now()
accounts := []*Account{
{ID: 3, Priority: 3, LastUsedAt: &now},
{ID: 1, Priority: 1, LastUsedAt: &now},
{ID: 2, Priority: 2, LastUsedAt: &now},
}
sortAccountsByPriorityAndLastUsed(accounts, false)
require.Equal(t, int64(1), accounts[0].ID)
require.Equal(t, int64(2), accounts[1].ID)
require.Equal(t, int64(3), accounts[2].ID)
})
}

View File

@@ -23,8 +23,7 @@ import (
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
// - 模型限流超过阈值:清理
// - 模型限流未超过阈值:不清理
// - 模型限流(任意时长):清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
@@ -35,9 +34,9 @@ func TestShouldClearStickySession(t *testing.T) {
future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour)
// 短限流时间(低于阈值,不应清除粘性会话)
// 短限流时间(有限流即清除粘性会话)
shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339)
// 长限流时间(超过阈值,应清除粘性会话)
// 长限流时间(有限流即清除粘性会话)
longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339)
tests := []struct {
@@ -53,7 +52,7 @@ func TestShouldClearStickySession(t *testing.T) {
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true},
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false},
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false},
// 模型限流测试
// 模型限流测试:有限流即清除
{
name: "model rate limited short duration",
account: &Account{
@@ -68,7 +67,7 @@ func TestShouldClearStickySession(t *testing.T) {
},
},
requestedModel: "claude-sonnet-4",
want: false, // 低于阈值,不清除
want: true, // 有限流即清除
},
{
name: "model rate limited long duration",
@@ -84,7 +83,7 @@ func TestShouldClearStickySession(t *testing.T) {
},
},
requestedModel: "claude-sonnet-4",
want: true, // 超过阈值,清除
want: true, // 有限流即清除
},
{
name: "model rate limited different model",

Some files were not shown because too many files have changed in this diff Show More